zwimpee committed on
Commit
a7ef49b
1 Parent(s): 659a118

uploading preprocessing script, model code, and training script

Files changed (3)
  1. model.py +365 -0
  2. preprocessing.py +113 -0
  3. train.py +260 -0
model.py ADDED
@@ -0,0 +1,365 @@
+ #./experiments/experiment1/model.py
+ import logging
+
+ logging.basicConfig(level=logging.DEBUG)
+
+ import math
+ import inspect  # used by configure_optimizers to detect the fused AdamW option
+ from dataclasses import dataclass
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from prereqs.nanoGPT.model import GPTConfig, GPT, MLP
+
+ # set up logger
+ logger = logging.getLogger(__name__)
+ logger.setLevel(logging.DEBUG)
+
+ def new_rielu(x):
+     # tanh-approximation GELU variant (currently unused; the MLP below calls F.gelu directly)
+     return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+ @dataclass
+ class RotationallyInvariantGPTConfig:
+     block_size: int = 512
+     vocab_size: int = 50304  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
+     n_layer: int = 6
+     n_head: int = 8
+     n_embd: int = 768
+     dropout: float = 0.0
+     bias: bool = True  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
+     rotational_invariance: bool = True  # set to True to enable the rotationally invariant gate layers
+
+ # Models
+ class RotationInvariantLayerNorm(nn.Module):
+     """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
+
+     def __init__(self, ndim, bias):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(ndim))
+         self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
+         self.rotation_gate = nn.Linear(ndim, ndim, bias=False)  # no bias needed for rotation
+         self.rotation_gate.weight.data = torch.eye(ndim)
+
+     def forward(self, input, rotation_matrix=None):
+         # apply rotation
+         if rotation_matrix is not None:
+             input = torch.matmul(input, self.rotation_gate(rotation_matrix))
+
+         # normalize
+         return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
+
+ class RotationallyInvariantAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+         self.attn_dropout = nn.Dropout(config.dropout)
+         self.resid_dropout = nn.Dropout(config.dropout)
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+         self.dropout = config.dropout
+         self.gate_q = nn.Linear(config.n_embd // config.n_head, 1, bias=config.bias)
+         self.gate_k = nn.Linear(config.n_embd // config.n_head, 1, bias=config.bias)
+
+     def forward(self, x, rotation_matrix=None):
+         logging.debug(f'x.size(): {x.size()}')
+
+         B, T, C = x.size()
+
+         logging.debug(f'B: {B}, T: {T}, C: {C}')
+
+         q, k, v = self.c_attn(x).chunk(3, dim=-1)
+
+         logging.debug('Pre-Reshape Q, K, and V')
+         logging.debug(f'q.size(): {q.size()}, k.size(): {k.size()}, v.size(): {v.size()}')
+         logging.debug(f'q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}')
+
+         # Reshape q, k, and v to (B, n_head, T, head_dim) to match att_dotproduct and att_rotation
+         q = q.view(B, T, self.n_head, C // self.n_head).permute(0, 2, 1, 3)
+         k = k.view(B, T, self.n_head, C // self.n_head).permute(0, 2, 1, 3)
+         v = v.view(B, T, self.n_head, C // self.n_head).permute(0, 2, 1, 3)
+
+         logging.debug('Post-Reshape Q, K, and V')
+         logging.debug(f'q.size(): {q.size()}, k.size(): {k.size()}, v.size(): {v.size()}')
+         logging.debug(f'q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}')
+
+         # Compute gate_q and gate_k such that they have the same leading shape as q and k
+         gate_q = torch.sigmoid(self.gate_q(q.view(B, self.n_head, T, -1)))
+         gate_k = torch.sigmoid(self.gate_k(k.view(B, self.n_head, T, -1)))
+
+         # Traditional dot-product attention
+         qk_dot = q @ k.transpose(-2, -1)
+         att_dotproduct = qk_dot / math.sqrt(self.n_embd)
+
+         # Rotation-invariant attention based on pairwise squared Euclidean distances
+         q_norm = torch.sum(q * q, dim=-1, keepdim=True)
+         k_norm = torch.sum(k * k, dim=-1, keepdim=True)
+         distances = q_norm + k_norm.transpose(-2, -1) - 2 * qk_dot
+         distances = torch.clamp(distances, min=0.0)  # guard against tiny negative values from floating-point error
+         att_rotation = -torch.sqrt(distances)
+         att_rotation = att_rotation / math.sqrt(self.n_embd)
+
+         # Apply gating to mix the two attention scores
+         mixed_att = att_dotproduct * gate_q + att_rotation * (torch.ones_like(gate_q) - gate_q)
+         att_scores = mixed_att / gate_k
+
+         if rotation_matrix is not None:
+             att_scores = att_scores + rotation_matrix
+
+         att_weights = F.softmax(att_scores, dim=-1)
+         y = att_weights @ v
+         y = y.permute(0, 2, 1, 3).contiguous().view(B, T, C)
+
+         y = self.resid_dropout(self.c_proj(y))
+         return y
+
+ class RotationallyInvariantMLP(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
+         self.dropout = nn.Dropout(config.dropout)
+         self.rotation_gate = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)  # added rotational gate layer
+         self.rotation_gate.weight.data = torch.eye(config.n_embd)  # initial rotation matrix is the identity
+
+     def forward(self, x, rotation_matrix=None):
+         x = self.c_fc(x)
+         x = F.gelu(x)
+         x = self.c_proj(x)
+         x = self.dropout(x)
+
+         # Rotational invariance part
+         if rotation_matrix is not None:
+             x = torch.matmul(x, self.rotation_gate(rotation_matrix))
+
+         return x
+
+ class RotationallyInvariantBlock(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = RotationInvariantLayerNorm(config.n_embd, bias=config.bias)
+         self.attn = RotationallyInvariantAttention(config)
+         self.ln_2 = RotationInvariantLayerNorm(config.n_embd, bias=config.bias)
+         self.mlp = RotationallyInvariantMLP(config)
+
+     def forward(self, x, rotation_matrix=None):
+         x = x + self.attn(self.ln_1(x), rotation_matrix)
+         x = x + self.mlp(self.ln_2(x), rotation_matrix)
+         return x
+
+ class RotationallyInvariantGPT(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         assert config.vocab_size is not None
+         assert config.block_size is not None
+         self.config = config
+
+         self.transformer = nn.ModuleDict(dict(
+             wte = nn.Embedding(config.vocab_size, config.n_embd),
+             wpe = nn.Embedding(config.block_size, config.n_embd),
+             drop = nn.Dropout(config.dropout),
+             h = nn.ModuleList([RotationallyInvariantBlock(config) for _ in range(config.n_layer)]),
+             ln_f = RotationInvariantLayerNorm(config.n_embd, bias=config.bias),
+         ))
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+         # with weight tying when using torch.compile() some warnings get generated:
+         # "UserWarning: functional_call was passed multiple values for tied weights.
+         # This behavior is deprecated and will be an error in future versions"
+         # not 100% sure what this is, so far seems to be harmless. TODO investigate
+         self.transformer.wte.weight = self.lm_head.weight  # https://paperswithcode.com/method/weight-tying
+
+         # init all weights
+         self.apply(self._init_weights)
+         # apply special scaled init to the residual projections, per GPT-2 paper
+         for pn, p in self.named_parameters():
+             if pn.endswith('c_proj.weight'):
+                 torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
+
+         # report number of parameters
+         print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
+
+     def get_num_params(self, non_embedding=True):
+         """
+         Return the number of parameters in the model.
+         For non-embedding count (default), the position embeddings get subtracted.
+         The token embeddings would too, except due to the parameter sharing these
+         params are actually used as weights in the final layer, so we include them.
+         """
+         n_params = sum(p.numel() for p in self.parameters())
+         if non_embedding:
+             n_params -= self.transformer.wpe.weight.numel()
+         return n_params
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, idx, targets=None):
+         device = idx.device
+         b, t = idx.size()
+         assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+         pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)
+
+         # forward the GPT model itself
+         tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
+         pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, t, n_embd)
+         x = self.transformer.drop(tok_emb + pos_emb)
+         for block in self.transformer.h:
+             x = block(x)
+         x = self.transformer.ln_f(x)
+
+         if targets is not None:
+             # if we are given some desired targets also calculate the loss
+             logits = self.lm_head(x)
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+         else:
+             # inference-time mini-optimization: only forward the lm_head on the very last position
+             logits = self.lm_head(x[:, [-1], :])  # note: using list [-1] to preserve the time dim
+             loss = None
+
+         return logits, loss
+
+     def crop_block_size(self, block_size):
+         # model surgery to decrease the block size if necessary
+         # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
+         # but want to use a smaller block size for some smaller, simpler model
+         assert block_size <= self.config.block_size
+         self.config.block_size = block_size
+         self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
+         for block in self.transformer.h:
+             if hasattr(block.attn, 'bias'):
+                 block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]
+
+     @classmethod
+     def from_pretrained(cls, model_type, override_args=None):
+         assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
+         override_args = override_args or {}  # default to empty dict
+         # only dropout can be overridden, see more notes below
+         assert all(k == 'dropout' for k in override_args)
+         from transformers import GPT2LMHeadModel
+         print("loading weights from pretrained gpt: %s" % model_type)
+
+         # n_layer, n_head and n_embd are determined from model_type
+         config_args = {
+             'gpt2':        dict(n_layer=12, n_head=12, n_embd=768),   # 124M params
+             'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024),  # 350M params
+             'gpt2-large':  dict(n_layer=36, n_head=20, n_embd=1280),  # 774M params
+             'gpt2-xl':     dict(n_layer=48, n_head=25, n_embd=1600),  # 1558M params
+         }[model_type]
+         print("forcing vocab_size=50257, block_size=1024, bias=True")
+         config_args['vocab_size'] = 50257  # always 50257 for GPT model checkpoints
+         config_args['block_size'] = 1024   # always 1024 for GPT model checkpoints
+         config_args['bias'] = True         # always True for GPT model checkpoints
+         # we can override the dropout rate, if desired
+         if 'dropout' in override_args:
+             print(f"overriding dropout rate to {override_args['dropout']}")
+             config_args['dropout'] = override_args['dropout']
+         # create a from-scratch initialized minGPT model
+         config = GPTConfig(**config_args)
+         model = GPT(config)
+         sd = model.state_dict()
+         sd_keys = sd.keys()
+         sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')]  # discard this mask / buffer, not a param
+
+         # init a huggingface/transformers model
+         model_hf = GPT2LMHeadModel.from_pretrained(model_type)
+         sd_hf = model_hf.state_dict()
+
+         # copy while ensuring all of the parameters are aligned and match in names and shapes
+         sd_keys_hf = sd_hf.keys()
+         sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')]  # ignore these, just a buffer
+         sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')]  # same, just the mask (buffer)
+         transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
+         # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
+         # this means that we have to transpose these weights when we import them
+         assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
+         for k in sd_keys_hf:
+             if any(k.endswith(w) for w in transposed):
+                 # special treatment for the Conv1D weights we need to transpose
+                 assert sd_hf[k].shape[::-1] == sd[k].shape
+                 with torch.no_grad():
+                     sd[k].copy_(sd_hf[k].t())
+             else:
+                 # vanilla copy over the other parameters
+                 assert sd_hf[k].shape == sd[k].shape
+                 with torch.no_grad():
+                     sd[k].copy_(sd_hf[k])
+
+         return model
+
+     def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
+         # start with all of the candidate parameters
+         param_dict = {pn: p for pn, p in self.named_parameters()}
+         # filter out those that do not require grad
+         param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
+         # create optim groups. Any parameter that is 2D will be weight decayed, otherwise no.
+         # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
+         decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
+         nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
+         optim_groups = [
+             {'params': decay_params, 'weight_decay': weight_decay},
+             {'params': nodecay_params, 'weight_decay': 0.0}
+         ]
+         num_decay_params = sum(p.numel() for p in decay_params)
+         num_nodecay_params = sum(p.numel() for p in nodecay_params)
+         print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
+         print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
+         # Create AdamW optimizer and use the fused version if it is available
+         fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
+         use_fused = fused_available and device_type == 'cuda'
+         extra_args = dict(fused=True) if use_fused else dict()
+         optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
+         print(f"using fused AdamW: {use_fused}")
+
+         return optimizer
+
+     def estimate_mfu(self, fwdbwd_per_iter, dt):
+         """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
+         # first estimate the number of flops we do per iteration.
+         # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
+         N = self.get_num_params()
+         cfg = self.config
+         L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd // cfg.n_head, cfg.block_size
+         flops_per_token = 6*N + 12*L*H*Q*T
+         flops_per_fwdbwd = flops_per_token * T
+         flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
+         # express our flops throughput as ratio of A100 bfloat16 peak flops
+         flops_achieved = flops_per_iter * (1.0/dt)  # per second
+         flops_promised = 312e12  # A100 GPU bfloat16 peak flops is 312 TFLOPS
+         mfu = flops_achieved / flops_promised
+         return mfu
+
+     @torch.no_grad()
+     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+         """
+         Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
+         the sequence max_new_tokens times, feeding the predictions back into the model each time.
+         Most likely you'll want to make sure to be in model.eval() mode of operation for this.
+         """
+         for _ in range(max_new_tokens):
+             # if the sequence context is growing too long we must crop it at block_size
+             idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
+             # forward the model to get the logits for the index in the sequence
+             logits, _ = self(idx_cond)
+             # pluck the logits at the final step and scale by desired temperature
+             logits = logits[:, -1, :] / temperature
+             # optionally crop the logits to only the top k options
+             if top_k is not None:
+                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits[logits < v[:, [-1]]] = -float('Inf')
+             # apply softmax to convert logits to (normalized) probabilities
+             probs = F.softmax(logits, dim=-1)
+             # sample from the distribution
+             idx_next = torch.multinomial(probs, num_samples=1)
+             # append sampled index to the running sequence and continue
+             idx = torch.cat((idx, idx_next), dim=1)
+
+         return idx
+
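Not part of the commit, but a minimal smoke test is a quick way to confirm that the module above wires together. The sketch below is a hedged illustration: it assumes model.py is on the import path and uses a deliberately tiny config (2 layers, 128-dim embeddings) instead of the defaults.

# hypothetical smoke test for model.py (not included in this commit)
import torch
from model import RotationallyInvariantGPT, RotationallyInvariantGPTConfig

config = RotationallyInvariantGPTConfig(
    block_size=64, vocab_size=50304, n_layer=2, n_head=4, n_embd=128, dropout=0.0, bias=True
)
model = RotationallyInvariantGPT(config)
model.eval()

idx = torch.randint(0, config.vocab_size, (1, 16))  # dummy (batch, time) token ids
logits, loss = model(idx, targets=idx)              # training-style call returns a loss
print(logits.shape, loss.item())                    # logits: (1, 16, 50304)

out = model.generate(idx, max_new_tokens=8, top_k=50)
print(out.shape)                                    # (1, 24)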
preprocessing.py ADDED
@@ -0,0 +1,113 @@
+ #./experiments/experiment1/preprocessing.py
+ import logging
+ import os
+ import sqlite3
+ from transformers import GPT2TokenizerFast
+ from datasets import load_dataset
+
+ class DatabaseInterface(object):
+     def __init__(self, db_file):
+         self.db_file = db_file
+
+     def create_table(self, table_name=None):
+         conn = sqlite3.connect(self.db_file)
+         c = conn.cursor()
+         c.execute(
+             '''
+             CREATE TABLE IF NOT EXISTS plain_text (
+                 text TEXT,
+                 split TEXT
+             )
+             '''
+         )
+         conn.commit()
+         conn.close()
+
+     def write_plain_text(self, example, split):
+         conn = sqlite3.connect(self.db_file)
+         c = conn.cursor()
+         c.execute("INSERT INTO plain_text (text, split) VALUES (?, ?)",
+                   (example, split))
+         conn.commit()
+         conn.close()
+
+
+ def process_and_write(example, writer, split):
+     writer.write_plain_text(example, split)
+
+
+ def prepare_data(start_index, end_index, **kwargs):
+     data_writer = kwargs['data_writer']
+     train_dataset = kwargs['train_dataset']
+     val_dataset = kwargs['val_dataset']
+
+     for split, dataset in {'val': val_dataset, 'train': train_dataset}.items():
+         subset = dataset[start_index:end_index]  # select the subset based on start and end indices
+
+         if isinstance(subset, dict):
+             subset = subset["text"]  # extract the "text" part from the subset dictionary
+
+         for example in subset:
+             process_and_write(example, data_writer, split)
+
+
+ if __name__ == '__main__':
+     logging.basicConfig(
+         format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+         datefmt='%m/%d/%Y %H:%M:%S',
+         level=logging.INFO
+     )
+
+     # Configs
+     batch_size = 32
+     num_processes = 4  # number of jobs to run simultaneously
+
+     logging.info("Creating Database Interface")
+     db_file_path = os.path.join('data', 'experiment1.db')
+
+     _delete_db = True
+
+     # Check to see if the database file already exists
+     if os.path.exists(db_file_path):
+         if _delete_db:
+             logging.info(f"Database file {db_file_path} already exists. Deleting it.")
+             os.remove(db_file_path)
+             data_writer = DatabaseInterface(db_file_path)
+             data_writer.create_table()
+             logging.info("Database table `plain_text` created")
+         else:
+             logging.info(f"Database file {db_file_path} already exists. Connecting to it.")
+             data_writer = DatabaseInterface(db_file_path)
+     else:
+         data_writer = DatabaseInterface(db_file_path)
+         data_writer.create_table()
+         logging.info("Database table `plain_text` created")
+
+     # Optionally point at a pre-downloaded local cache, e.g.:
+     # cache_dir = os.path.join(
+     #     'C:/Users/User/.cache/huggingface/datasets/openwebtext/plain_text',
+     #     '1.0.0',
+     #     '6f68e85c16ccc770c0dd489f4008852ea9633604995addd0cd76e293aed9e521'
+     # )
+     cache_dir = None  # None falls back to the default Hugging Face cache location
+
+     dataset = load_dataset(
+         "openwebtext",
+         cache_dir=cache_dir,
+         num_proc=num_processes,
+         save_infos=True,
+         writer_batch_size=batch_size
+     )
+
+     split_dataset = dataset["train"].train_test_split(test_size=0.1, seed=42, shuffle=False)
+     train_dataset = split_dataset["train"]
+     val_dataset = split_dataset["test"]
+
+     prepare_data(
+         start_index=0,
+         end_index=1000,
+         **{
+             'data_writer': data_writer,
+             'train_dataset': train_dataset,
+             'val_dataset': val_dataset,
+         }
+     )
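A quick sanity check on the script's output (again, not part of the commit) is to count the rows written per split. This sketch assumes the same data/experiment1.db path used above.

# hypothetical check of the SQLite table written by preprocessing.py
import sqlite3

conn = sqlite3.connect('data/experiment1.db')
c = conn.cursor()
for split in ('train', 'val'):
    c.execute("SELECT COUNT(*) FROM plain_text WHERE split = ?", (split,))
    print(split, c.fetchone()[0])
conn.close()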
train.py ADDED
@@ -0,0 +1,260 @@
+ #./experiments/experiment1/train.py
+ import logging
+ import pickle
+ import sqlite3
+ import torch
+ import torchvision
+ import torch.optim as optim
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import transformers
+
+ from model import RotationallyInvariantGPT, RotationallyInvariantGPTConfig
+ from prereqs.nanoGPT.model import GPTConfig, GPT, MLP
+ from datasets import load_from_disk
+ from torch.utils.data import DataLoader
+
+ from transformers import GPT2TokenizerFast
+
+ from torch.nn.utils.rnn import pad_sequence
+
+ def pad_collate(batch):
+     # Separate inputs and labels
+     inputs = [d['input_ids'] for d in batch]
+     labels = [d['labels'] for d in batch]
+
+     # Pad the input sequences
+     input_tensor = pad_sequence(inputs, batch_first=True)
+
+     # Pad the label sequences
+     label_tensor = pad_sequence(labels, batch_first=True)
+
+     return {'input_ids': input_tensor, 'labels': label_tensor}
+
+ class DatabaseInterface(object):
+     def __init__(self, db_file):
+         self.db_file = db_file
+
+     def read(self, split):
+         conn = sqlite3.connect(self.db_file)
+         c = conn.cursor()
+         c.execute(f"SELECT * FROM plain_text WHERE split='{split}'")
+         col_names = [desc[0] for desc in c.description]  # get column names
+         results = [dict(zip(col_names, row)) for row in c.fetchall()]  # convert tuples to dictionaries
+         conn.close()
+         return results
+
+
+ class PlainTextDataset(torch.utils.data.Dataset):
+     def __init__(self, plain_text_dataset, tokenizer, device):
+         self.plain_text_dataset = plain_text_dataset
+         self.tokenizer = tokenizer
+         self.device = device
+
+     def __len__(self):
+         return len(self.plain_text_dataset)
+
+     def __getitem__(self, idx):
+         item = self.plain_text_dataset[idx]
+         tokens = self.tokenizer.encode_plus(item["text"], truncation=True, max_length=512, padding="max_length")
+         input_ids = tokens["input_ids"]
+         attention_mask = tokens["attention_mask"]
+         return {
+             'input_ids': torch.as_tensor(input_ids[:-1], dtype=torch.long).to(self.device),
+             'attention_mask': torch.as_tensor(attention_mask[:-1], dtype=torch.long).to(self.device),
+             'labels': torch.as_tensor(input_ids[1:], dtype=torch.long).to(self.device)
+         }
+
+ def train(model: nn.Module, optimizer: optim.Optimizer, train_loader: DataLoader) -> float:
+     model.train()
+     running_loss = 0
+     for i, batch in enumerate(train_loader):
+         inputs, targets = batch['input_ids'].to(device), batch['labels'].to(device)
+         optimizer.zero_grad()
+         outputs, loss = model(inputs, targets)
+         loss.backward()
+         optimizer.step()
+         running_loss += loss.item()
+         if i % 100 == 0:
+             logging.info(f"Batch {i}: Loss={loss.item()}")
+     return running_loss / len(train_loader)
+
+
+ def evaluate(model, valid_loader) -> float:
+     model.eval()
+     running_loss = 0
+     with torch.no_grad():
+         for i, batch in enumerate(valid_loader):
+             inputs, targets = batch['input_ids'].to(device), batch['labels'].to(device)
+             logits, loss = model(inputs, targets)  # the model returns a (logits, loss) tuple
+             running_loss += loss.item()
+             if i % 100 == 0:
+                 logging.info(f"Batch {i}: Validation Loss={loss.item()}")
+     return running_loss / len(valid_loader)
+
+ if __name__ == '__main__':
+     logging.basicConfig(
+         format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+         datefmt='%m/%d/%Y %H:%M:%S',
+         level=logging.INFO
+     )
+     logging.info(f"PyTorch version: {torch.__version__}")
+     logging.info(f"Torchvision version: {torchvision.__version__}")
+     logging.info(f"Transformers version: {transformers.__version__}")
+     logging.info(f"CUDA version: {torch.version.cuda}")
+     logging.info(f"cuDNN version: {torch.backends.cudnn.version()}")
+
+     logging.info("Clearing cuda cache...")
+     torch.cuda.empty_cache()
+
+     logging.info("Setting num_threads to 1...")
+     torch.set_num_threads(1)
+
+     # Configs
+     d_model = 512
+     num_heads = 4
+     num_layers = 1
+     block_size = 512
+     dropout = 0.2
+     bias = True
+     rotational = True
+     batch_size = 32
+     eval_batch_size = 64
+     epochs = 10
+     lr = 0.001
+
+     vocab_size = 50304  # GPT-2 vocab size (50257) padded up to a multiple of 64, matching the model config
+     logging.info(f"Vocab size: {vocab_size}")
+
+     logging.info(f'''
+         Config:
+         d_model={d_model},
+         num_heads={num_heads},
+         num_layers={num_layers},
+         block_size={block_size},
+         dropout={dropout}, bias={bias}
+         '''
+     )
+     logging.info(
+         f"Training for {epochs} epochs with a learning rate of {lr}..."
+     )
+
+     logging.info(f"Batch size: {batch_size}")
+     logging.info(f"Eval batch size: {eval_batch_size}")
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     # device = torch.device("cpu")
+     logging.info(f"Device: {device}")
+
+     logging.info("Loading tokenizer")
+     tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
+     tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+
+     # Query the database for the plain text data
+     logging.info("Querying plain text data...")
+
+     db_file_path = "data/experiment1.db"
+
+     plain_text_train = DatabaseInterface(db_file_path).read("train")
+     #logging.debug(f"Plain text train: {plain_text_train[:10]}")
+
+     plain_text_val = DatabaseInterface(db_file_path).read("val")
+     #logging.debug(f"Plain text val: {plain_text_val[:10]}")
+
+     # Create train/val dataset objects
+     train_dataset = PlainTextDataset(plain_text_train, tokenizer, device)
+     valid_dataset = PlainTextDataset(plain_text_val, tokenizer, device)
+
+     # DEBUG
+     #for idx, item in enumerate(train_dataset):
+     #    input_ids = item["input_ids"]
+     #    attention_mask = item["attention_mask"]
+     #    if input_ids.size(0) == 0:
+     #        print(f"Sample index with 0 length: {idx}")
+     #        print(f"Input_ids: {input_ids}")
+     #        print(f"Attention_mask: {attention_mask}")
+
+     # Calculate the number of batches
+     num_train_batches = len(train_dataset) // batch_size
+     num_eval_batches = len(valid_dataset) // eval_batch_size
+
+     logging.info(f"Number of train batches: {num_train_batches}")
+     logging.info(f"Number of eval batches: {num_eval_batches}")
+
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=batch_size,
+         shuffle=False,
+         collate_fn=pad_collate
+     )
+
+     valid_loader = DataLoader(
+         valid_dataset,
+         batch_size=eval_batch_size,
+         shuffle=False,
+         collate_fn=pad_collate
+     )
+
+     # gpt_config = GPTConfig(
+     #     vocab_size=vocab_size,
+     #     n_embd=d_model,
+     #     n_head=num_heads,
+     #     n_layer=num_layers,
+     #     block_size=block_size,
+     #     dropout=dropout,
+     #     bias=bias
+     # )
+
+     rigpt_config = RotationallyInvariantGPTConfig(
+         vocab_size=vocab_size,
+         n_embd=d_model,
+         n_head=num_heads,
+         n_layer=num_layers,
+         block_size=block_size,
+         dropout=dropout,
+         bias=bias,
+         rotational_invariance=rotational
+     )
+
+     logging.info("Creating models...")
+     # gpt = GPT(gpt_config).to(device)
+     rigpt = RotationallyInvariantGPT(rigpt_config).to(device)
+
+     logging.info("Creating optimizers...")
+     # optimizer_gpt = optim.Adam(gpt.parameters(), lr=lr)
+     optimizer_rigpt = optim.Adam(rigpt.parameters(), lr=lr)
+
+     logging.info("Training...")
+     for model, optimizer, model_name in [
+         # (
+         #     gpt,
+         #     optimizer_gpt,
+         #     'GPT'
+         # ),
+         (
+             rigpt,
+             optimizer_rigpt,
+             'RotationallyInvariantGPT'
+         )
+     ]:
+         print(f"Training {model_name}")
+         for epoch in range(1, epochs + 1):
+             print(f"Training epoch {epoch}")
+             train_loss = train(model, optimizer, train_loader)
+             print(f"Validating epoch {epoch}")
+             valid_loss = evaluate(model, valid_loader)  # pass the DataLoader, not the batch count
+             print(
+                 f'''
+                 {model_name} -
+                 Epoch: {epoch},
+                 Train loss: {train_loss:.3f},
+                 Validation loss: {valid_loss:.3f}
+                 '''
+             )
+
+     # torch.save(gpt.state_dict(), "gpt.pt")
+     torch.save(rigpt.state_dict(), "rigpt.pt")
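For completeness, a hedged inference sketch (not in the commit): reload the saved rigpt.pt checkpoint and sample from it. It assumes the same hyperparameters train.py uses (n_embd=512, n_head=4, n_layer=1, block_size=512, vocab_size=50304, bias=True), since the state-dict shapes must match.

# hypothetical sampling script for the rigpt.pt checkpoint produced by train.py
import torch
from transformers import GPT2TokenizerFast
from model import RotationallyInvariantGPT, RotationallyInvariantGPTConfig

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = RotationallyInvariantGPTConfig(
    vocab_size=50304, n_embd=512, n_head=4, n_layer=1, block_size=512, dropout=0.0, bias=True
)
model = RotationallyInvariantGPT(config).to(device)
model.load_state_dict(torch.load("rigpt.pt", map_location=device))
model.eval()

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
prompt = torch.tensor([tokenizer.encode("The meaning of life is")], dtype=torch.long, device=device)
out = model.generate(prompt, max_new_tokens=20, temperature=0.8, top_k=40)
print(tokenizer.decode(out[0].tolist()))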