raja5259 commited on
Commit
67e4e9f
·
verified ·
1 Parent(s): 4ff1c2a

update RohanGivenCode.py with max_len i.e., new_tokens

Browse files
Files changed (1) hide show
  1. RohanGivenCode.py +300 -300
RohanGivenCode.py CHANGED
@@ -1,300 +1,300 @@
1
-
2
- import os
3
- import math
4
- import time
5
- import inspect
6
- from dataclasses import dataclass
7
- import torch
8
- import torch.nn as nn
9
- from torch.nn import functional as F
10
- import tiktoken
11
-
12
- #1 --- Seema start here
13
- class CausalSelfAttention(nn.Module):
14
-
15
- def __init__(self, config):
16
- super().__init__()
17
- assert config.n_embd % config.n_head == 0
18
- # key, query, value projections for all heads, but in a batch
19
- self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
20
- # output projection
21
- self.c_proj = nn.Linear(config.n_embd, config.n_embd)
22
- self.c_proj.NANGPT_SCALE_INIT = 1
23
- # regularization
24
- self.n_head = config.n_head
25
- self.n_embd = config.n_embd
26
- self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))
27
-
28
- def forward(self, x):
29
- B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
30
- # calculate query, key, values for all heads in batch and move head forward to be the batch dim
31
- # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
32
- # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
33
- qkv = self.c_attn(x)
34
- q, k, v = qkv.split(self.n_embd, dim=2)
35
- k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
36
- q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
37
- v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
38
-
39
- # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
40
- # att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
41
- # att = F.softmax(att, dim=-1)
42
- # y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
43
-
44
- y = F.scaled_dot_product_attention(q, k, v, is_causal = True) # Flash attention
45
-
46
- y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
47
- # output projection
48
- y = self.c_proj(y)
49
- return y
50
-
51
-
52
- class MLP(nn.Module):
53
-
54
- def __init__(self, config):
55
- super().__init__()
56
- self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
57
- self.gelu = nn.GELU(approximate='tanh')
58
- self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
59
- self.c_proj.NANOGPT_SCALE_INIT = 1
60
-
61
- def forward(self, x):
62
- x = self.c_fc(x)
63
- x = self.gelu(x)
64
- x = self.c_proj(x)
65
- return x
66
-
67
- class Block(nn.Module):
68
-
69
- def __init__(self, config):
70
- super().__init__()
71
- self.ln_1 = nn.LayerNorm(config.n_embd)
72
- self.attn = CausalSelfAttention(config)
73
- self.ln_2 = nn.LayerNorm(config.n_embd)
74
- self.mlp = MLP(config)
75
-
76
- def forward(self, x):
77
- x = x + self.attn(self.ln_1(x))
78
- x = x + self.mlp(self.ln_2(x))
79
- return x
80
-
81
-
82
- @dataclass
83
- class GPTConfig:
84
- block_size: int = 1024 # max sequence length
85
- vocab_size: int = 50304 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
86
- n_layer: int = 12 # number of layers
87
- n_head: int = 12 # number of heads
88
- n_embd: int = 768 # embedding dimension
89
-
90
-
91
-
92
- class GPT(nn.Module):
93
-
94
- def __init__(self, config):
95
- super().__init__()
96
- self.config = config
97
-
98
- self.transformer = nn.ModuleDict(dict(
99
- wte = nn.Embedding(config.vocab_size, config.n_embd),
100
- wpe = nn.Embedding(config.block_size, config.n_embd),
101
- h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
102
- ln_f = nn.LayerNorm(config.n_embd),
103
- ))
104
- self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
105
-
106
- # weight sharing
107
- self.transformer.wte.weight = self.lm_head.weight
108
-
109
- # weight initialization
110
- self.apply(self._init_weights)
111
-
112
- def _init_weights(self, module):
113
- if isinstance(module, nn.Linear):
114
- std = 0.02
115
- if hasattr(module, 'NANGPT_SCALE_INIT'):
116
- std *= (2 * self.config.n_layer) ** -0.5
117
- torch.nn.init.normal_(module.weight, mean = 0.0, std = std)
118
- if module.bias is not None:
119
- torch.nn.init.zeros_(module.bias)
120
- elif isinstance(module, nn.Embedding):
121
- torch.nn.init.normal_(module.weight, mean=0.0, std = 0.02)
122
-
123
- #1 --- Seema end here
124
-
125
- #============================================================================================================
126
-
127
- #2 --- Raja start here
128
- def forward(self, idx, targets=None):
129
- # idx is of shape (B, T)
130
- B, T = idx.size()
131
- assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
132
- # forward the token and posisition embeddings
133
- pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)
134
- pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)
135
- tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
136
- x = tok_emb + pos_emb
137
- # forward the blocks of the transformer
138
- for block in self.transformer.h:
139
- x = block(x)
140
- # forward the final layernorm and the classifier
141
- x = self.transformer.ln_f(x)
142
- logits = self.lm_head(x) # (B, T, vocab_size)
143
- loss = None
144
- if targets is not None:
145
- loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
146
- return logits, loss
147
-
148
- def configure_optimizers(self, weight_decay, learning_rate, device_type):
149
- # start with all of the candidate parameters (that require grad)
150
- param_dict = {pn: p for pn, p in self.named_parameters()}
151
- param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
152
- # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
153
- # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
154
- decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
155
- nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
156
- optim_groups = [
157
- {'params': decay_params, 'weight_decay': weight_decay},
158
- {'params': nodecay_params, 'weight_decay': 0.0}
159
- ]
160
- num_decay_params = sum(p.numel() for p in decay_params)
161
- num_nodecay_params = sum(p.numel() for p in nodecay_params)
162
-
163
- print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
164
- print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
165
- # Create AdamW optimizer and use the fused version if it is available
166
- fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
167
- use_fused = fused_available and device_type == "cuda"
168
-
169
- print(f"using fused AdamW: {use_fused}")
170
- optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
171
- return optimizer
172
-
173
- # model = GPT.from_pretrained('gpt2')
174
-
175
-
176
- #2 --- Raja end here
177
-
178
-
179
- #============================================================================================================
180
-
181
- #3 --- Yasaswini start here
182
- class DataLoaderLite:
183
- def __init__(self, B, T, text_input):
184
- self.B = B
185
- self.T = T
186
-
187
- self.enc = tiktoken.get_encoding('gpt2')
188
- tokens = self.enc.encode(text_input)
189
- self.tokens = torch.tensor(tokens)
190
- print(f'loaded {len(self.tokens)} tokens')
191
- print(f'1 epoch = {len(self.tokens) // (B * T)} batches')
192
-
193
- # state
194
- self.current_position = 0
195
-
196
- def next_batch(self):
197
- B, T = self.B, self.T
198
- buf = self.tokens[self.current_position: self.current_position + B * T + 1]
199
- x = (buf[:-1]).view(B, T) # inputs
200
- y = (buf[1:]).view(B, T) # targets
201
- # advance the position in the tensor
202
- self.current_position += B*T
203
- # if loading the next batch would be out of bounds, reset
204
- if self.current_position + (B * T + 1) > len(self.tokens):
205
- self.current_position = 0
206
- return x, y
207
-
208
- def get_model(device):
209
- # CHANGES IN CURRENT CODE
210
- torch.set_float32_matmul_precision('high')
211
- model = GPT(GPTConfig())
212
- model.to(device)
213
- # model = torch.compile(model)
214
- return model
215
-
216
-
217
-
218
-
219
- def get_lr(it):
220
- # CODE UPDATE HERE
221
- # warmup_steps = 10
222
- # max_steps = 50
223
- warmup_steps = 100
224
-
225
- max_lr = 6e-4
226
- min_lr = max_lr * 0.1
227
- if it < warmup_steps:
228
- return max_lr * (it + 1) / warmup_steps
229
- if it > max_steps:
230
- return min_lr
231
- decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
232
- assert 0 <= decay_ratio <=1
233
- coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
234
- return min_lr + coeff * (max_lr - min_lr)
235
-
236
-
237
-
238
- # optimizer = torch.optim.AdamW(model.parameters(), lr = 3e-4, betas=(0.9, 0.95), eps=1e-8)
239
- def train_the_model(train_loader):
240
- model = get_model(device)
241
- optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4, device_type=device)
242
- for step in range(max_steps):
243
- t0 = time.time()
244
- x, y = train_loader.next_batch()
245
- x, y = x.to(device), y.to(device)
246
- optimizer.zero_grad()
247
- # NEW CODE ADDED HERE
248
- with torch.autocast(device_type=device, dtype=torch.bfloat16):
249
- logits, loss = model(x, y)
250
- loss.backward()
251
- norm = torch.nn.utils.clip_grad_norm(model.parameters(), 1.0)
252
- # NEW CODE
253
- lr = get_lr(step)
254
- for param_group in optimizer.param_groups:
255
- param_group['lr'] = lr
256
-
257
- optimizer.step()
258
- torch.cuda.synchronize()
259
- t1 = time.time()
260
- dt = (t1 - t0) * 1000
261
- tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
262
- print(f'step{step} | loss: {loss.item()} | dt: {dt:.2f}ms | tok/sec: {tokens_per_sec: .2f} | norm: {norm:.2f}')
263
- return model, loss
264
-
265
-
266
- #From here inference
267
- def infer_the_model(device, test_loader, save1_or_load0):
268
- x, y = test_loader.next_batch()
269
- model = get_model(device)
270
- if save1_or_load0 == 0:
271
- model.load_state_dict(torch.load('model_weights.pth', map_location=torch.device(device)))
272
- torch.manual_seed(42)
273
- torch.cuda.manual_seed(42)
274
- while x.size(1) < max_length:
275
- # forward the model to get the logits
276
- with torch.no_grad():
277
- logits = model(x)[0] # (B, T, vocab_size)
278
- # take the logits at the last position
279
- logits = logits[:, -1, :] # (B, vocab_size)
280
- # get the probabilities
281
- probs = F.softmax(logits, dim=-1)
282
- # do top-k sampling of 50 (huggingface pipeline default)
283
- # topk_probs here becomes (5, 50), topk_indices is (5, 50)
284
- topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
285
- # select a token from the top-k probabilities
286
- # note: multinomial does not demand the input to sum to 1
287
- ix = torch.multinomial(topk_probs, 1) # (B, 1)
288
- # gather the corresponding indices
289
- xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
290
- # append to the sequence
291
- x = torch.cat((x, xcol), dim=1)
292
-
293
- # print the generated text
294
- retval = ""
295
- for i in range(num_return_sequences):
296
- tokens = x[i, :max_length].tolist()
297
- decoded = test_loader.enc.decode(tokens)
298
- print(">", decoded)
299
- retval += decoded
300
- return retval
 
1
+
2
+ import os
3
+ import math
4
+ import time
5
+ import inspect
6
+ from dataclasses import dataclass
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn import functional as F
10
+ import tiktoken
11
+
12
+ #1 --- Seema start here
13
+ class CausalSelfAttention(nn.Module):
14
+
15
+ def __init__(self, config):
16
+ super().__init__()
17
+ assert config.n_embd % config.n_head == 0
18
+ # key, query, value projections for all heads, but in a batch
19
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
20
+ # output projection
21
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd)
22
+ self.c_proj.NANGPT_SCALE_INIT = 1
23
+ # regularization
24
+ self.n_head = config.n_head
25
+ self.n_embd = config.n_embd
26
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))
27
+
28
+ def forward(self, x):
29
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
30
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
31
+ # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
32
+ # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
33
+ qkv = self.c_attn(x)
34
+ q, k, v = qkv.split(self.n_embd, dim=2)
35
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
36
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
37
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
38
+
39
+ # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
40
+ # att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
41
+ # att = F.softmax(att, dim=-1)
42
+ # y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
43
+
44
+ y = F.scaled_dot_product_attention(q, k, v, is_causal = True) # Flash attention
45
+
46
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
47
+ # output projection
48
+ y = self.c_proj(y)
49
+ return y
50
+
51
+
52
+ class MLP(nn.Module):
53
+
54
+ def __init__(self, config):
55
+ super().__init__()
56
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
57
+ self.gelu = nn.GELU(approximate='tanh')
58
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
59
+ self.c_proj.NANOGPT_SCALE_INIT = 1
60
+
61
+ def forward(self, x):
62
+ x = self.c_fc(x)
63
+ x = self.gelu(x)
64
+ x = self.c_proj(x)
65
+ return x
66
+
67
+ class Block(nn.Module):
68
+
69
+ def __init__(self, config):
70
+ super().__init__()
71
+ self.ln_1 = nn.LayerNorm(config.n_embd)
72
+ self.attn = CausalSelfAttention(config)
73
+ self.ln_2 = nn.LayerNorm(config.n_embd)
74
+ self.mlp = MLP(config)
75
+
76
+ def forward(self, x):
77
+ x = x + self.attn(self.ln_1(x))
78
+ x = x + self.mlp(self.ln_2(x))
79
+ return x
80
+
81
+
82
+ @dataclass
83
+ class GPTConfig:
84
+ block_size: int = 1024 # max sequence length
85
+ vocab_size: int = 50304 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
86
+ n_layer: int = 12 # number of layers
87
+ n_head: int = 12 # number of heads
88
+ n_embd: int = 768 # embedding dimension
89
+
90
+
91
+
92
+ class GPT(nn.Module):
93
+
94
+ def __init__(self, config):
95
+ super().__init__()
96
+ self.config = config
97
+
98
+ self.transformer = nn.ModuleDict(dict(
99
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
100
+ wpe = nn.Embedding(config.block_size, config.n_embd),
101
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
102
+ ln_f = nn.LayerNorm(config.n_embd),
103
+ ))
104
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
105
+
106
+ # weight sharing
107
+ self.transformer.wte.weight = self.lm_head.weight
108
+
109
+ # weight initialization
110
+ self.apply(self._init_weights)
111
+
112
+ def _init_weights(self, module):
113
+ if isinstance(module, nn.Linear):
114
+ std = 0.02
115
+ if hasattr(module, 'NANGPT_SCALE_INIT'):
116
+ std *= (2 * self.config.n_layer) ** -0.5
117
+ torch.nn.init.normal_(module.weight, mean = 0.0, std = std)
118
+ if module.bias is not None:
119
+ torch.nn.init.zeros_(module.bias)
120
+ elif isinstance(module, nn.Embedding):
121
+ torch.nn.init.normal_(module.weight, mean=0.0, std = 0.02)
122
+
123
+ #1 --- Seema end here
124
+
125
+ #============================================================================================================
126
+
127
+ #2 --- Raja start here
128
+ def forward(self, idx, targets=None):
129
+ # idx is of shape (B, T)
130
+ B, T = idx.size()
131
+ assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
132
+ # forward the token and posisition embeddings
133
+ pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)
134
+ pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)
135
+ tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
136
+ x = tok_emb + pos_emb
137
+ # forward the blocks of the transformer
138
+ for block in self.transformer.h:
139
+ x = block(x)
140
+ # forward the final layernorm and the classifier
141
+ x = self.transformer.ln_f(x)
142
+ logits = self.lm_head(x) # (B, T, vocab_size)
143
+ loss = None
144
+ if targets is not None:
145
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
146
+ return logits, loss
147
+
148
+ def configure_optimizers(self, weight_decay, learning_rate, device_type):
149
+ # start with all of the candidate parameters (that require grad)
150
+ param_dict = {pn: p for pn, p in self.named_parameters()}
151
+ param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
152
+ # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
153
+ # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
154
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
155
+ nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
156
+ optim_groups = [
157
+ {'params': decay_params, 'weight_decay': weight_decay},
158
+ {'params': nodecay_params, 'weight_decay': 0.0}
159
+ ]
160
+ num_decay_params = sum(p.numel() for p in decay_params)
161
+ num_nodecay_params = sum(p.numel() for p in nodecay_params)
162
+
163
+ print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
164
+ print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
165
+ # Create AdamW optimizer and use the fused version if it is available
166
+ fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
167
+ use_fused = fused_available and device_type == "cuda"
168
+
169
+ print(f"using fused AdamW: {use_fused}")
170
+ optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
171
+ return optimizer
172
+
173
+ # model = GPT.from_pretrained('gpt2')
174
+
175
+
176
+ #2 --- Raja end here
177
+
178
+
179
+ #============================================================================================================
180
+
181
+ #3 --- Yasaswini start here
182
+ class DataLoaderLite:
183
+ def __init__(self, B, T, text_input):
184
+ self.B = B
185
+ self.T = T
186
+
187
+ self.enc = tiktoken.get_encoding('gpt2')
188
+ tokens = self.enc.encode(text_input)
189
+ self.tokens = torch.tensor(tokens)
190
+ print(f'loaded {len(self.tokens)} tokens')
191
+ print(f'1 epoch = {len(self.tokens) // (B * T)} batches')
192
+
193
+ # state
194
+ self.current_position = 0
195
+
196
+ def next_batch(self):
197
+ B, T = self.B, self.T
198
+ buf = self.tokens[self.current_position: self.current_position + B * T + 1]
199
+ x = (buf[:-1]).view(B, T) # inputs
200
+ y = (buf[1:]).view(B, T) # targets
201
+ # advance the position in the tensor
202
+ self.current_position += B*T
203
+ # if loading the next batch would be out of bounds, reset
204
+ if self.current_position + (B * T + 1) > len(self.tokens):
205
+ self.current_position = 0
206
+ return x, y
207
+
208
+ def get_model(device):
209
+ # CHANGES IN CURRENT CODE
210
+ torch.set_float32_matmul_precision('high')
211
+ model = GPT(GPTConfig())
212
+ model.to(device)
213
+ # model = torch.compile(model)
214
+ return model
215
+
216
+
217
+
218
+
219
+ def get_lr(it):
220
+ # CODE UPDATE HERE
221
+ # warmup_steps = 10
222
+ # max_steps = 50
223
+ warmup_steps = 100
224
+
225
+ max_lr = 6e-4
226
+ min_lr = max_lr * 0.1
227
+ if it < warmup_steps:
228
+ return max_lr * (it + 1) / warmup_steps
229
+ if it > max_steps:
230
+ return min_lr
231
+ decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
232
+ assert 0 <= decay_ratio <=1
233
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
234
+ return min_lr + coeff * (max_lr - min_lr)
235
+
236
+
237
+
238
+ # optimizer = torch.optim.AdamW(model.parameters(), lr = 3e-4, betas=(0.9, 0.95), eps=1e-8)
239
+ def train_the_model(train_loader):
240
+ model = get_model(device)
241
+ optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4, device_type=device)
242
+ for step in range(max_steps):
243
+ t0 = time.time()
244
+ x, y = train_loader.next_batch()
245
+ x, y = x.to(device), y.to(device)
246
+ optimizer.zero_grad()
247
+ # NEW CODE ADDED HERE
248
+ with torch.autocast(device_type=device, dtype=torch.bfloat16):
249
+ logits, loss = model(x, y)
250
+ loss.backward()
251
+ norm = torch.nn.utils.clip_grad_norm(model.parameters(), 1.0)
252
+ # NEW CODE
253
+ lr = get_lr(step)
254
+ for param_group in optimizer.param_groups:
255
+ param_group['lr'] = lr
256
+
257
+ optimizer.step()
258
+ torch.cuda.synchronize()
259
+ t1 = time.time()
260
+ dt = (t1 - t0) * 1000
261
+ tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
262
+ print(f'step{step} | loss: {loss.item()} | dt: {dt:.2f}ms | tok/sec: {tokens_per_sec: .2f} | norm: {norm:.2f}')
263
+ return model, loss
264
+
265
+
266
+ #From here inference
267
+ def infer_the_model(device, test_loader, save1_or_load0, max_length):
268
+ x, y = test_loader.next_batch()
269
+ model = get_model(device)
270
+ if save1_or_load0 == 0:
271
+ model.load_state_dict(torch.load('model_weights.pth', map_location=torch.device(device)))
272
+ torch.manual_seed(42)
273
+ torch.cuda.manual_seed(42)
274
+ while x.size(1) < max_length:
275
+ # forward the model to get the logits
276
+ with torch.no_grad():
277
+ logits = model(x)[0] # (B, T, vocab_size)
278
+ # take the logits at the last position
279
+ logits = logits[:, -1, :] # (B, vocab_size)
280
+ # get the probabilities
281
+ probs = F.softmax(logits, dim=-1)
282
+ # do top-k sampling of 50 (huggingface pipeline default)
283
+ # topk_probs here becomes (5, 50), topk_indices is (5, 50)
284
+ topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
285
+ # select a token from the top-k probabilities
286
+ # note: multinomial does not demand the input to sum to 1
287
+ ix = torch.multinomial(topk_probs, 1) # (B, 1)
288
+ # gather the corresponding indices
289
+ xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
290
+ # append to the sequence
291
+ x = torch.cat((x, xcol), dim=1)
292
+
293
+ # print the generated text
294
+ retval = ""
295
+ for i in range(num_return_sequences):
296
+ tokens = x[i, :max_length].tolist()
297
+ decoded = test_loader.enc.decode(tokens)
298
+ print(">", decoded)
299
+ retval += decoded
300
+ return retval