shivendrra committed on
Commit
484d56b
1 Parent(s): 82ea75a

added model files

enigma/EnBERT.py ADDED
@@ -0,0 +1,181 @@
1
+ """
2
+ this isn't a BERT-based model; i just liked the name
3
+ --> decoder-only model, uses RMS normalization and GELU activation function
4
+ --> one masked-attention pass and one unmasked pass per decoder block
5
+ --> attention layers use learned relative positional embeddings
6
+ """
7
+
8
+ import json
9
+ with open('config.json', 'r', encoding='utf-8') as file:
10
+ params = json.load(file)
11
+
12
+ # required parameters
13
+ block_size = params['block_size']
14
+ d_model = params['d_model']
15
+ n_head = params['n_heads']
16
+ n_layers = params['n_layers']
17
+ learning_rate = params['learning_rate']
18
+ dropout = params['dropout']
19
+ norm_eps = params['norm_eps']
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ from torch.nn import functional as F
24
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
25
+
26
+ class RMSNorm(nn.Module):
27
+ def __init__(self, dim: int, eps: float = 1e-6):
28
+ super().__init__()
29
+ self.eps = eps
30
+ self.weight = nn.Parameter(torch.ones(dim))
31
+
32
+ def _norm(self, x):
33
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
34
+
35
+ def forward(self, x):
36
+ output = self._norm(x.float()).type_as(x)
37
+ return output * self.weight
38
+
39
+ class SingleHead(nn.Module):
40
+ def __init__(self,
41
+ head_size: int,
42
+ d_model: int,
43
+ block_size: int,
44
+ dropout: float):
45
+ super().__init__()
46
+ self.key = nn.Linear(d_model, head_size, bias=True)
47
+ self.query = nn.Linear(d_model, head_size, bias=True)
48
+ self.value = nn.Linear(d_model, head_size, bias=True)
49
+ self.dropout = nn.Dropout(dropout)
50
+ self.rel_pos_embd = nn.Parameter(torch.randn(block_size, block_size, head_size))
51
+ self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
52
+
53
+ def forward(self, x: torch.Tensor, mask: bool = False):
54
+ B, T, C = x.shape
55
+ key = self.key(x)
56
+ query = self.query(x)
57
+ scores = torch.matmul(query, key.transpose(-2, -1)) * (key.shape[-1] ** -0.5)
58
+
59
+ if mask is True:
60
+ scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
61
+
62
+ rel_pos_scores = torch.einsum('btc,tvc->btv', query, self.rel_pos_embd[:T, :T])
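+ # relative-position bias: einsum 'btc,tvc->btv' lets each query position t add a learned score toward position v from rel_pos_embd[t, v]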
63
+ scores = scores + rel_pos_scores
64
+
65
+ att_mat = F.softmax(scores, dim=-1)
66
+ att_mat = self.dropout(att_mat)
67
+ value = self.value(x)
68
+ output = torch.matmul(att_mat, value)
69
+ return output
70
+
71
+ class MultiHeadAttention(nn.Module):
72
+ def __init__(self,
73
+ d_model: int,
74
+ block_size: int,
75
+ n_head : int,
76
+ dropout: float):
77
+ head_size = d_model // n_head
78
+ super().__init__()
79
+ self.heads = nn.ModuleList([SingleHead(d_model=d_model, dropout=dropout, block_size=block_size, head_size=head_size) for _ in range(n_head)])
80
+ self.projection = nn.Linear(n_head * head_size, d_model)
81
+ self.dropout = nn.Dropout(dropout)
82
+
83
+ def forward(self, x: torch.Tensor, mask: bool):
84
+ out = torch.cat([h(x, mask) for h in self.heads], dim=-1)
85
+ out = self.dropout(self.projection(out))
86
+ return out
87
+
88
+ class FeedForward(nn.Module):
89
+ def __init__(self, d_model, dropout):
90
+ super().__init__()
91
+ self.net = nn.Sequential(
92
+ nn.Linear(d_model, 5 * d_model),
93
+ nn.GELU(),
94
+ nn.Linear(5 * d_model, d_model),
95
+ nn.Dropout(dropout),
96
+ )
97
+
98
+ def forward(self, x: torch.Tensor):
99
+ return self.net(x)
100
+
101
+ class DecoderBlock(nn.Module):
102
+ def __init__(self, d_model: int,
103
+ block_size: int,
104
+ n_head: int,
105
+ norm_eps: float,
106
+ dropout: float):
107
+ super().__init__()
108
+ self.self_att = MultiHeadAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)
109
+ self.ffwd = FeedForward(d_model, dropout)
110
+ self.dropout = nn.Dropout(dropout)
111
+ self.norm = RMSNorm(d_model, eps=norm_eps)
112
+
113
+ def forward(self, x: torch.Tensor):
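+ # pre-norm residual blocks: masked self-attention, then a second (unmasked) self-attention pass, then the feed-forward network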
114
+ x_out = self.self_att(self.norm(x), mask=True)
115
+ x_out = x + self.dropout(x_out)
116
+ del x
117
+
118
+ x = self.self_att(self.norm(x_out), mask=False)
119
+ x = x_out + self.dropout(x)
120
+ del x_out
121
+
122
+ x_out = self.ffwd(self.norm(x))
123
+ x_out = x + self.dropout(x_out)
124
+ del x
125
+
126
+ return x_out
127
+
128
+ class Transformer(nn.Module):
129
+ def __init__(self, vocab_size: int):
130
+ super().__init__()
131
+ self.block_size = block_size
132
+ self.token_embeddings = nn.Embedding(vocab_size, d_model)
133
+ self.decoder = nn.Sequential(*[DecoderBlock(n_head=n_head, d_model=d_model, dropout=dropout, norm_eps=norm_eps, block_size=block_size) for _ in range(n_layers)])
134
+ self.norm_final = RMSNorm(d_model, eps=norm_eps)
135
+ self.linear_final = nn.Linear(d_model, vocab_size)
136
+ self.dropout = nn.Dropout(dropout)
137
+ self.apply(self._init_weights)
138
+
139
+ def _init_weights(self, module):
140
+ if isinstance(module, nn.Linear):
141
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
142
+ if module.bias is not None:
143
+ torch.nn.init.zeros_(module.bias.data)
144
+ elif isinstance(module, nn.Embedding):
145
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
146
+
147
+ def forward(self, idx, targets=None):
148
+ B, T = idx.shape
149
+ x = self.token_embeddings(idx)
150
+ x = self.decoder(x)
151
+ logits = self.linear_final(self.norm_final(x))
152
+
153
+ if targets is None:
154
+ loss = None
155
+
156
+ else:
157
+ B, T, C = logits.shape
158
+ logits = logits.view(B*T, C)
159
+ targets = targets.view(B*T)
160
+ loss = F.cross_entropy(logits, targets)
161
+
162
+ return logits, loss
163
+
164
+ @torch.no_grad()
165
+ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
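+ # autoregressive sampling: crop the context to block_size, scale logits by temperature, optionally keep only the top-k logits, then sample the next token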
166
+ self.eval()
167
+ for _ in range(max_new_tokens):
168
+
169
+ idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
170
+ logits, _ = self(idx_cond)
171
+ logits = logits[:, -1, :] / temperature
172
+
173
+ if top_k is not None:
174
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
175
+ logits[logits < v[:, [-1]]] = -float('Inf')
176
+
177
+ probs = F.softmax(logits, dim=-1)
178
+ idx_next = torch.multinomial(probs, num_samples=1)
179
+ idx = torch.cat((idx, idx_next), dim=1)
180
+
181
+ return idx
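
A minimal usage sketch for EnBERT.py (not part of this commit). It assumes a config.json with the keys read at the top of the file sits next to it, and that token ids come from a tokenizer such as the K-mer tokenizer in TrainEnigma.ipynb; the vocab_size value below is illustrative only:

import torch
from EnBERT import Transformer   # reads config.json at import time

device = 'cuda' if torch.cuda.is_available() else 'cpu'
vocab_size = 1000                                                # illustrative; use len(tokenizer.vocab)
model = Transformer(vocab_size).to(device)
context = torch.zeros((1, 1), dtype=torch.long, device=device)  # a single start-token id
out = model.generate(context, max_new_tokens=20, temperature=0.8, top_k=5)
print(out[0].tolist())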
enigma/TrainEnigma.ipynb ADDED
@@ -0,0 +1,470 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "vXIGN6PAuZWg"
7
+ },
8
+ "source": [
9
+ "### Train file for enigma model\n",
10
+ "\n",
11
+ "- Contains K-mer tokenizer, k=4, can be changed though\n",
12
+ "- Train data is available on huggingface repo: [hf/engima-1.5b](https://huggingface.co/shivendrra/enigma-1.5b)\n",
13
+ "- For now, trainig decoder-based model only\n",
14
+ "- More about this on github repo: [github/enigma-1.5b](https://github.com/shivendrra/enigma-1.5b)\n",
15
+ "- Saves model after training in '.pth' & '.safetensors' file for later use\n",
16
+ "- Generate function works fine"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "metadata": {
23
+ "id": "WXpJBLyr30Rx"
24
+ },
25
+ "outputs": [],
26
+ "source": [
27
+ "from google.colab import drive\n",
28
+ "drive.mount('/content/drive')"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": null,
34
+ "metadata": {
35
+ "id": "r7WUm0VL4bN4"
36
+ },
37
+ "outputs": [],
38
+ "source": [
39
+ "import torch\n",
40
+ "\n",
41
+ "# importing the data\n",
42
+ "file_path = '/content/drive/MyDrive/consolidated_dna.txt'\n",
43
+ "with open(file_path, 'r', encoding='utf-8') as file:\n",
44
+ " dna_seq = file.read()\n",
45
+ "file.close()\n",
46
+ "\n",
47
+ "print(f\"{(len(dna_seq)/1e6):.2f} million letters\")"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "execution_count": null,
53
+ "metadata": {
54
+ "id": "Cdhybhz9owTK"
55
+ },
56
+ "outputs": [],
57
+ "source": [
58
+ "import os\n",
59
+ "from tqdm import tqdm\n",
60
+ "import json\n",
61
+ "\n",
62
+ "class KMerTokenizer:\n",
63
+ " def __init__(self, k_mers: int=4):\n",
64
+ " self.k_mers = k_mers\n",
65
+ " self.vocab = {}\n",
66
+ " self.id_to_token = []\n",
67
+ " self.token_to_id = {}\n",
68
+ "\n",
69
+ " def tokenize_sequence(self, sequence):\n",
70
+ " kmers = [sequence[i:i+self.k_mers] for i in tqdm(range(0, len(sequence), self.k_mers), desc=\"tokenizing k-mers\")]\n",
71
+ " return kmers\n",
72
+ "\n",
73
+ " def build_vocab(self, sequences):\n",
74
+ " all_kmers = []\n",
75
+ " for sequence in sequences:\n",
76
+ " all_kmers.extend(self.tokenize_sequence(sequence))\n",
77
+ " token_count = {}\n",
78
+ " for kmer in all_kmers:\n",
79
+ " if kmer in token_count:\n",
80
+ " token_count[kmer] += 1\n",
81
+ " else:\n",
82
+ " token_count[kmer] = 1\n",
83
+ " sorted_tokens = sorted(token_count.items(), key=lambda x: x[1], reverse=True)\n",
84
+ " for token, _ in sorted_tokens:\n",
85
+ " self.token_to_id[token] = len(self.token_to_id)\n",
86
+ " self.id_to_token.append(token)\n",
87
+ " self.vocab = self.token_to_id\n",
88
+ "\n",
89
+ " def encode(self, sequence):\n",
90
+ " encoded_sequence = []\n",
91
+ " kmers = self.tokenize_sequence(sequence)\n",
92
+ " for kmer in tqdm(kmers, desc=\"encoding sequences\"):\n",
93
+ " if kmer in self.token_to_id:\n",
94
+ " encoded_sequence.append(self.token_to_id[kmer])\n",
95
+ " else:\n",
96
+ " encoded_sequence.append(len(self.vocab))\n",
97
+ " return encoded_sequence\n",
98
+ "\n",
99
+ " def decode(self, encoded_sequence):\n",
100
+ " decoded_sequence = [self.id_to_token[token_id] for token_id in encoded_sequence]\n",
101
+ " return decoded_sequence\n",
102
+ "\n",
103
+ " def save_model(self, model_path):\n",
104
+ " vocab_file = f\"{model_path}/base_{self.k_mers}k.json\"\n",
105
+ " with open(vocab_file, 'w') as f:\n",
106
+ " json.dump(self.vocab, f)\n",
107
+ "\n",
108
+ " def load_model(self, path):\n",
109
+ " assert path.endswith('.json')\n",
110
+ " with open(path, 'r') as f:\n",
111
+ " vocab = json.load(f)\n",
112
+ "\n",
113
+ " self.vocab = vocab\n",
114
+ " self.token_to_id = self.vocab\n",
115
+ " self.vocab_size = len(vocab)"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": null,
121
+ "metadata": {
122
+ "id": "6BCpjdi5rjU4"
123
+ },
124
+ "outputs": [],
125
+ "source": [
126
+ "token = KMerTokenizer()\n",
127
+ "token.build_vocab([dna_seq])\n",
128
+ "print(f\"vocab size: {len(token.vocab)}\")\n",
129
+ "print(token.id_to_token[:10])"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "code",
134
+ "execution_count": null,
135
+ "metadata": {
136
+ "id": "6Ou9txgmAdIB"
137
+ },
138
+ "outputs": [],
139
+ "source": [
140
+ "# Train and test splits\n",
141
+ "data = torch.tensor(token.encode(dna_seq), dtype=torch.long)\n",
142
+ "print(f\"{(len(data)/1e6):0f} million\"\")\n",
143
+ "n = int(0.9*len(data)) # first 90% will be train, rest val\n",
144
+ "train_data = data[:n]\n",
145
+ "val_data = data[n:]\n",
146
+ "print(f\"train data {(len(train_data)/1e6):.0f}million, val data {(len(val_data)/1e6):.0f}million\")"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": null,
152
+ "metadata": {
153
+ "id": "ebFKQQ9NAq4e"
154
+ },
155
+ "outputs": [],
156
+ "source": [
157
+ "# hyperparams\n",
158
+ "batch_size = 10\n",
159
+ "block_size = 256\n",
160
+ "max_iters = 5000\n",
161
+ "eval_interval = 100\n",
162
+ "learning_rate = 3e-5\n",
163
+ "eval_iters = 100\n",
164
+ "d_model = 512\n",
165
+ "n_layers = 12\n",
166
+ "n_head = 18\n",
167
+ "dropout = 0.25\n",
168
+ "norm_eps = 1e-5"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": null,
174
+ "metadata": {
175
+ "id": "dZMiYkr37cmU"
176
+ },
177
+ "outputs": [],
178
+ "source": [
179
+ "import torch.nn as nn\n",
180
+ "from torch.nn import functional as F\n",
181
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
182
+ "\n",
183
+ "class RMSNorm(nn.Module):\n",
184
+ " def __init__(self, dim: int, eps: float = 1e-6):\n",
185
+ " super().__init__()\n",
186
+ " self.eps = eps\n",
187
+ " self.weight = nn.Parameter(torch.ones(dim))\n",
188
+ "\n",
189
+ " def _norm(self, x):\n",
190
+ " return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n",
191
+ "\n",
192
+ " def forward(self, x):\n",
193
+ " output = self._norm(x.float()).type_as(x)\n",
194
+ " return output * self.weight\n",
195
+ "\n",
196
+ "class SingleHead(nn.Module):\n",
197
+ " def __init__(self,\n",
198
+ " head_size: int,\n",
199
+ " d_model: int,\n",
200
+ " block_size: int,\n",
201
+ " dropout: float):\n",
202
+ " super().__init__()\n",
203
+ " self.key = nn.Linear(d_model, head_size, bias=True)\n",
204
+ " self.query = nn.Linear(d_model, head_size, bias=True)\n",
205
+ " self.value = nn.Linear(d_model, head_size, bias=True)\n",
206
+ " self.dropout = nn.Dropout(dropout)\n",
207
+ " self.rel_pos_embd = nn.Parameter(torch.randn(block_size, block_size, head_size))\n",
208
+ " self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
209
+ "\n",
210
+ " def forward(self, x: torch.Tensor, mask: bool= False):\n",
211
+ " B, T, C = x.shape\n",
212
+ " key = self.key(x)\n",
213
+ " query = self.query(x)\n",
214
+ " scores = torch.matmul(query ,key.transpose(-2, -1)) / (key.shape[-1]**-0.5)\n",
215
+ "\n",
216
+ " if mask is True:\n",
217
+ " scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))\n",
218
+ "\n",
219
+ " rel_pos_scores = torch.einsum('btc,tvc->btv', query, self.rel_pos_embd[:T, :T])\n",
220
+ " scores = scores + rel_pos_scores\n",
221
+ "\n",
222
+ " att_mat = F.softmax(scores, dim=-1)\n",
223
+ " att_mat = self.dropout(att_mat)\n",
224
+ " value = self.value(x)\n",
225
+ " output = torch.matmul(att_mat, value)\n",
226
+ " return output\n",
227
+ "\n",
228
+ "class MultiHeadAttention(nn.Module):\n",
229
+ " def __init__(self,\n",
230
+ " d_model: int,\n",
231
+ " block_size: int,\n",
232
+ " n_head : int,\n",
233
+ " dropout: float):\n",
234
+ " head_size = d_model // n_head\n",
235
+ " super().__init__()\n",
236
+ " self.heads = nn.ModuleList([SingleHead(d_model=d_model, dropout=dropout, block_size=block_size, head_size=head_size) for _ in range(n_head)])\n",
237
+ " self.projection = nn.Linear(d_model, d_model)\n",
238
+ " self.dropout = nn.Dropout(dropout)\n",
239
+ "\n",
240
+ " def forward(self, x: torch.Tensor, mask: bool):\n",
241
+ " out = torch.cat([h(x, mask) for h in self.heads], dim=-1)\n",
242
+ " out = self.dropout(self.projection(out))\n",
243
+ " return out\n",
244
+ "\n",
245
+ "class FeedForward(nn.Module):\n",
246
+ " def __init__(self, d_model, dropout):\n",
247
+ " super().__init__()\n",
248
+ " self.net = nn.Sequential(\n",
249
+ " nn.Linear(d_model, 5 * d_model),\n",
250
+ " nn.GELU(),\n",
251
+ " nn.Linear(5 * d_model, d_model),\n",
252
+ " nn.Dropout(dropout),\n",
253
+ " )\n",
254
+ "\n",
255
+ " def forward(self, x: torch.Tensor):\n",
256
+ " return self.net(x)\n",
257
+ "\n",
258
+ "class DecoderBlock(nn.Module):\n",
259
+ " def __init__(self, d_model: int,\n",
260
+ " block_size: int,\n",
261
+ " n_head: int,\n",
262
+ " norm_eps: float,\n",
263
+ " dropout: float):\n",
264
+ " super().__init__()\n",
265
+ " self.self_att = MultiHeadAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)\n",
266
+ " self.ffwd = FeedForward(d_model, dropout)\n",
267
+ " self.dropout = nn.Dropout(dropout)\n",
268
+ " self.norm = RMSNorm(d_model, eps=norm_eps)\n",
269
+ "\n",
270
+ " def forward(self, x: torch.Tensor):\n",
271
+ " x_out = self.self_att(self.norm(x), mask=True)\n",
272
+ " x_out = x + self.dropout(x_out)\n",
273
+ " del x\n",
274
+ "\n",
275
+ " x = self.self_att(self.norm(x_out, mask=False))\n",
276
+ " x = x_out + self.dropout(x)\n",
277
+ " del x_out\n",
278
+ "\n",
279
+ " x_out = self.ffwd(self.norm(x))\n",
280
+ " x_out = x + self.dropout(x_out)\n",
281
+ " del x\n",
282
+ "\n",
283
+ " return x_out\n",
284
+ "\n",
285
+ "class Transformer(nn.Module):\n",
286
+ " def __init__(self, vocab_size: int):\n",
287
+ " super().__init__()\n",
288
+ " self.block_size = block_size\n",
289
+ " self.token_embeddings = nn.Embedding(vocab_size, d_model)\n",
290
+ " self.decoder = nn.Sequential(*[DecoderBlock(n_head=n_head, d_model=d_model, dropout=dropout, norm_eps=norm_eps, block_size=block_size) for _ in range(n_layers)])\n",
291
+ " self.norm_final = RMSNorm(d_model, eps=norm_eps)\n",
292
+ " self.linear_final = nn.Linear(d_model, vocab_size)\n",
293
+ " self.dropout = nn.Dropout(dropout)\n",
294
+ " self.apply(self._init_weights)\n",
295
+ "\n",
296
+ " def _init_weights(self, module):\n",
297
+ " if isinstance(module, nn.Linear):\n",
298
+ " torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
299
+ " if module.bias is not None:\n",
300
+ " torch.nn.init.zeros_(module.bias.data)\n",
301
+ " elif isinstance(module, nn.Embedding):\n",
302
+ " torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
303
+ "\n",
304
+ " def forward(self, idx, targets=None):\n",
305
+ " B, T = idx.shape\n",
306
+ " x = self.token_embeddings(idx)\n",
307
+ " x = self.decoder(x)\n",
308
+ " logits = self.linear_final(self.norm_final(x))\n",
309
+ "\n",
310
+ " if targets is None:\n",
311
+ " loss = None\n",
312
+ "\n",
313
+ " else:\n",
314
+ " B, T, C = logits.shape\n",
315
+ " logits = logits.view(B*T, C)\n",
316
+ " targets = targets.view(B*T)\n",
317
+ " loss = F.cross_entropy(logits, targets)\n",
318
+ "\n",
319
+ " return logits, loss\n",
320
+ "\n",
321
+ " @torch.no_grad()\n",
322
+ " def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):\n",
323
+ " self.eval()\n",
324
+ " for _ in range(max_new_tokens):\n",
325
+ "\n",
326
+ " idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]\n",
327
+ " logits, _ = self(idx_cond)\n",
328
+ " logits = logits[:, -1, :] / temperature\n",
329
+ "\n",
330
+ " if top_k is not None:\n",
331
+ " v, _ = torch.topk(logits, min(top_k, logits.size(-1)))\n",
332
+ " logits[logits < v[:, [-1]]] = -float('Inf')\n",
333
+ "\n",
334
+ " probs = F.softmax(logits, dim=-1)\n",
335
+ " idx_next = torch.multinomial(probs, num_samples=1)\n",
336
+ " idx = torch.cat((idx, idx_next), dim=1)\n",
337
+ "\n",
338
+ " return idx"
339
+ ]
340
+ },
341
+ {
342
+ "cell_type": "code",
343
+ "execution_count": null,
344
+ "metadata": {
345
+ "id": "X9VOBZFr7g3W"
346
+ },
347
+ "outputs": [],
348
+ "source": [
349
+ "import timeit\n",
350
+ "start_time = timeit.default_timer()\n",
351
+ "\n",
352
+ "def get_batch(split):\n",
353
+ " data = train_data if split == 'train' else val_data\n",
354
+ " ix = torch.randint(len(data) - block_size, (batch_size,))\n",
355
+ " x = torch.stack([data[i:i+block_size] for i in ix])\n",
356
+ " y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
357
+ " x, y = x.to(device), y.to(device)\n",
358
+ " return x, y\n",
359
+ "\n",
360
+ "@torch.no_grad()\n",
361
+ "def estimate_loss():\n",
362
+ " out = {}\n",
363
+ " model.eval()\n",
364
+ " for split in ['train', 'val']:\n",
365
+ " losses = torch.zeros(eval_iters)\n",
366
+ " for k in range(eval_iters):\n",
367
+ " X, Y = get_batch(split)\n",
368
+ " logits, loss = model(X, Y)\n",
369
+ " losses[k] = loss.item()\n",
370
+ " out[split] = losses.mean()\n",
371
+ " model.train()\n",
372
+ " return out\n",
373
+ "\n",
374
+ "vocab_size = len(token.vocab)\n",
375
+ "model = Transformer(vocab_size)\n",
376
+ "# checkpoint_path = '/content/drive/MyDrive/enigma-2.5b.pth'\n",
377
+ "# checkpoint = torch.load(checkpoint_path)\n",
378
+ "# model.load_state_dict(checkpoint)\n",
379
+ "m = model.to(device)\n",
380
+ "\n",
381
+ "# no of parameters\n",
382
+ "n_param = sum(p.numel() for p in m.parameters())/1e6\n",
383
+ "print(f\"{n_param:.1f} million parameters\")\n",
384
+ "\n",
385
+ "# optimizer\n",
386
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
387
+ "steps = []\n",
388
+ "train_losses = []\n",
389
+ "val_losses = []\n",
390
+ "\n",
391
+ "for iter in range(max_iters):\n",
392
+ "\n",
393
+ " if iter % eval_interval == 0 or iter == max_iters - 1:\n",
394
+ " losses = estimate_loss()\n",
395
+ " print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n",
396
+ "\n",
397
+ " steps.append(iter)\n",
398
+ " train_losses.append(losses['train'])\n",
399
+ " val_losses.append(losses['val'])\n",
400
+ "\n",
401
+ " xb, yb = get_batch('train')\n",
402
+ " logits, loss = model(xb, yb)\n",
403
+ " optimizer.zero_grad(set_to_none=True)\n",
404
+ " loss.backward()\n",
405
+ " optimizer.step()"
406
+ ]
407
+ },
408
+ {
409
+ "cell_type": "code",
410
+ "execution_count": null,
411
+ "metadata": {
412
+ "id": "tzJMKoA35uIV"
413
+ },
414
+ "outputs": [],
415
+ "source": [
416
+ "end_time = timeit.default_timer()\n",
417
+ "print(f\"total parameters: {n_param:.1f} billion\")\n",
418
+ "print(f\"trained in {((end_time - start_time)/3600):.2f}hrs\")"
419
+ ]
420
+ },
421
+ {
422
+ "cell_type": "code",
423
+ "execution_count": null,
424
+ "metadata": {
425
+ "id": "eB47Yn9aNrrO"
426
+ },
427
+ "outputs": [],
428
+ "source": [
429
+ "model_save_name = f'consolidated_00.pth'\n",
430
+ "path = f\"/content/drive/MyDrive/{model_save_name}\"\n",
431
+ "torch.save(model.state_dict(), path)\n",
432
+ "\n",
433
+ "# saving safe-tensors\n",
434
+ "from safetensors.torch import save_file\n",
435
+ "\n",
436
+ "model_save_name = f'consolidated_00.safetensors'\n",
437
+ "path = f\"/content/drive/MyDrive/{model_save_name}\"\n",
438
+ "save_file(model.state_dict(), path)"
439
+ ]
440
+ },
441
+ {
442
+ "cell_type": "code",
443
+ "execution_count": null,
444
+ "metadata": {
445
+ "id": "89TNah_89CRB"
446
+ },
447
+ "outputs": [],
448
+ "source": [
449
+ "!nvidia-smi"
450
+ ]
451
+ }
452
+ ],
453
+ "metadata": {
454
+ "accelerator": "GPU",
455
+ "colab": {
456
+ "gpuType": "T4",
457
+ "machine_shape": "hm",
458
+ "provenance": []
459
+ },
460
+ "kernelspec": {
461
+ "display_name": "Python 3",
462
+ "name": "python3"
463
+ },
464
+ "language_info": {
465
+ "name": "python"
466
+ }
467
+ },
468
+ "nbformat": 4,
469
+ "nbformat_minor": 0
470
+ }
enigma/config_enigma.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "batch_size": 10,
3
+ "block_size": 512,
4
+ "max_iters": 5000,
5
+ "eval_interval": 50,
6
+ "learning_rate": 3e-5,
7
+ "eval_iters": 100,
8
+ "d_model": 384,
9
+ "n_head": 12,
10
+ "n_layer": 12,
11
+ "dropout": 0.2,
12
+ "norm_eps": 1e-5
13
+ }
enigma/enigma.cpp ADDED
@@ -0,0 +1,364 @@
1
+ #include <torch/torch.h>
2
+ #include <iostream>
3
+ #include <vector>
+ #include <cmath>
+ #include <limits>
+ #include <algorithm>
4
+
5
+ // Define device
6
+ torch::Device device(torch::kCUDA);
7
+
8
+ // Define constants
9
+ const int batch_size = 8;
10
+ const int block_size = 32;
11
+ const int max_iters = 1000;
12
+ const int eval_interval = 50;
13
+ const int eval_iters = 5;
14
+ const int d_model = 256;
15
+ const int n_layer = 16;
16
+ const int n_head = 12;
17
+ const float dropout = 0.2;
18
+ const float norm_eps = 1e-5;
19
+ const int vocab_size = 5;
+ const float learning_rate = 3e-5;
20
+
21
+ // sample data
22
+ torch::Tensor train_data = torch::rand({1000, block_size});
23
+ torch::Tensor val_data = torch::rand({500, block_size});
24
+
25
+ // Data loading function
26
+ std::pair<torch::Tensor, torch::Tensor> get_batch(const std::string& split) {
27
+ torch::Tensor data = (split == "train") ? train_data : val_data;
28
+ torch::Tensor ix = torch::randint(data.size(0) - block_size, {batch_size});
29
+ torch::Tensor x = torch::empty({batch_size, block_size});
30
+ torch::Tensor y = torch::empty({batch_size, block_size});
31
+ for (int i = 0; i < batch_size; ++i) {
32
+ x[i] = data.index({ix[i], ix[i] + block_size});
33
+ y[i] = data.index({ix[i] + 1, ix[i] + block_size + 1});
34
+ }
35
+ return std::make_pair(x.to(device), y.to(device));
36
+ }
37
+
38
+ // Custom classes and functions
39
+ class SWiGLU : public torch::nn::Module {
40
+ public:
41
+ SWiGLU() {}
42
+
43
+ torch::Tensor forward(torch::Tensor x) {
44
+ torch::Tensor sigmoid_output = torch::sigmoid(x);
45
+ torch::Tensor relu_output = torch::relu(x);
46
+ torch::Tensor out = sigmoid_output * relu_output + (1 - sigmoid_output) * x;
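+ // note: this is an element-wise gating variant; the standard SwiGLU gates one learned linear projection with the SiLU of another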
47
+ return out;
48
+ }
49
+ };
50
+
51
+ class UnMaskedHeadImpl : public torch::nn::Module {
52
+ public:
53
+ UnMaskedHeadImpl(int d_model, int head_size, float dropout)
54
+ : key(register_module("key", torch::nn::Linear(d_model, head_size))),
55
+ query(register_module("query", torch::nn::Linear(d_model, head_size))),
56
+ value(register_module("value", torch::nn::Linear(d_model, head_size))),
57
+ dropout(torch::nn::Dropout(dropout)) {
58
+ register_module("dropout", dropout);
59
+ }
60
+
61
+ torch::Tensor forward(torch::Tensor x) {
62
+ torch::Tensor key_out = key->forward(x);
63
+ torch::Tensor query_out = query->forward(x);
64
+
65
+ torch::Tensor weights = query_out.matmul(key_out.transpose(-2, -1)) / std::sqrt((double)key_out.size(-1));
66
+ weights = torch::softmax(weights, -1);
67
+ weights = dropout(weights);
68
+
69
+ torch::Tensor value_out = value->forward(x);
70
+ torch::Tensor out = weights.matmul(value_out);
71
+ return out;
72
+ }
73
+
74
+ private:
75
+ torch::nn::Linear key, query, value;
76
+ torch::nn::Dropout dropout;
77
+ };
78
+
79
+ TORCH_MODULE(UnMaskedHead);
80
+
81
+ class MaskedHeadImpl : public torch::nn::Module {
82
+ public:
83
+ MaskedHeadImpl(int head_size, float dropout, int d_model)
84
+ : key(register_module("key", torch::nn::Linear(d_model, head_size))),
85
+ query(register_module("query", torch::nn::Linear(d_model, head_size))),
86
+ value(register_module("value", torch::nn::Linear(d_model, head_size))),
87
+ dropout(torch::nn::Dropout(dropout)) {
88
+ register_buffer("tril", torch::tril(torch::ones(block_size, block_size)));
89
+ }
90
+
91
+ torch::Tensor forward(torch::Tensor x) {
92
+ torch::Tensor key_out = key->forward(x);
93
+ torch::Tensor query_out = query->forward(x);
94
+
95
+ torch::Tensor weights = query_out.matmul(key_out.transpose(-2, -1)) / std::sqrt((double)key_out.size(-1));
96
+ weights = weights.masked_fill(tril.index({torch::indexing::Slice(0, x.size(1)), torch::indexing::Slice(0, x.size(1))}) == 0, std::numeric_limits<float>::lowest());
97
+ weights = torch::softmax(weights, -1);
98
+ weights = dropout(weights);
99
+
100
+ torch::Tensor value_out = value->forward(x);
101
+ torch::Tensor out = weights.matmul(value_out);
102
+ return out;
103
+ }
104
+
105
+ private:
106
+ torch::nn::Linear key, query, value;
107
+ torch::nn::Dropout dropout;
108
+ torch::Tensor tril;
109
+ };
110
+
111
+ TORCH_MODULE(MaskedHead);
112
+
113
+ class MultiUnMaskedImpl : public torch::nn::Module {
114
+ public:
115
+ MultiUnMaskedImpl(int d_model, int n_head, float dropout)
116
+ : proj(register_module("proj", torch::nn::Linear(n_head * (d_model / n_head), d_model))),
117
+ dropout(torch::nn::Dropout(dropout)) {
118
+ for (int i = 0; i < n_head; ++i) {
119
+ heads.push_back(register_module("head" + std::to_string(i), UnMaskedHead(d_model, d_model / n_head, dropout)));
120
+ }
121
+ }
122
+
123
+ torch::Tensor forward(torch::Tensor x) {
124
+ std::vector<torch::Tensor> head_outputs;
125
+ for (auto& head : heads) {
126
+ head_outputs.push_back(head->forward(x));
127
+ }
128
+ torch::Tensor out = torch::cat(head_outputs, -1);
129
+ out = dropout(out);
130
+ out = proj(out);
131
+ return out;
132
+ }
133
+
134
+ private:
135
+ torch::nn::Linear proj;
136
+ torch::nn::Dropout dropout;
137
+ std::vector<UnMaskedHead> heads;
138
+ };
139
+
140
+ TORCH_MODULE(MultiUnMasked);
141
+
142
+ class MultiMaskedImpl : public torch::nn::Module {
143
+ public:
144
+ MultiMaskedImpl(int d_model, int n_head, float dropout)
145
+ : proj(register_module("proj", torch::nn::Linear(n_head * (d_model / n_head), d_model))),
146
+ dropout(torch::nn::Dropout(dropout)) {
147
+ for (int i = 0; i < n_head; ++i) {
148
+ heads.push_back(register_module("head" + std::to_string(i), MaskedHead(d_model, d_model / n_head, dropout)));
149
+ }
150
+ }
151
+
152
+ torch::Tensor forward(torch::Tensor x) {
153
+ std::vector<torch::Tensor> head_outputs;
154
+ for (auto& head : heads) {
155
+ head_outputs.push_back(head->forward(x));
156
+ }
157
+ torch::Tensor out = torch::cat(head_outputs, -1);
158
+ out = dropout(out);
159
+ out = proj(out);
160
+ return out;
161
+ }
162
+
163
+ private:
164
+ torch::nn::Linear proj;
165
+ torch::nn::Dropout dropout;
166
+ std::vector<MaskedHead> heads;
167
+ };
168
+
169
+ TORCH_MODULE(MultiMasked);
170
+
171
+ class FeedForwardImpl : public torch::nn::Module {
172
+ public:
173
+ FeedForwardImpl(int d_model, float dropout)
174
+ : net(register_module("net", torch::nn::Sequential(
175
+ torch::nn::Linear(d_model, 4 * d_model),
176
+ torch::nn::GELU(),
177
+ torch::nn::Linear(4 * d_model, d_model),
178
+ torch::nn::Dropout(dropout)
179
+ ))) {}
180
+
181
+ torch::Tensor forward(torch::Tensor x) {
182
+ return net->forward(x);
183
+ }
184
+
185
+ private:
186
+ torch::nn::Sequential net;
187
+ };
188
+
189
+ TORCH_MODULE(FeedForward);
190
+
191
+ class BlockImpl : public torch::nn::Module {
192
+ public:
193
+ BlockImpl(int d_model, int n_head, float norm_eps, float dropout)
194
+ : sa_masked(MultiMasked(d_model, n_head, dropout)),
195
+ sa_unmasked(MultiUnMasked(d_model, n_head, dropout)),
196
+ ffwd(FeedForward(d_model, dropout)),
197
+ norm1(torch::nn::LayerNorm(torch::nn::LayerNormOptions({d_model}).eps(norm_eps))),
198
+ norm2(torch::nn::LayerNorm(torch::nn::LayerNormOptions({d_model}).eps(norm_eps))) {}
199
+
200
+ torch::Tensor forward(torch::Tensor x) {
201
+ torch::Tensor x2 = x + sa_unmasked->forward(norm1->forward(x));
202
+ x = x2 + ffwd->forward(norm2->forward(x2));
203
+
204
+ x2 = x + sa_masked->forward(norm1->forward(x));
205
+ x = x2 + ffwd->forward(norm2->forward(x2));
206
+
207
+ return x;
208
+ }
209
+
210
+ private:
211
+ MultiMasked sa_masked;
212
+ MultiUnMasked sa_unmasked;
213
+ FeedForward ffwd;
214
+ torch::nn::LayerNorm norm1, norm2;
215
+ };
216
+
217
+ TORCH_MODULE(Block);
218
+
219
+ class EnigmaImpl : public torch::nn::Module {
220
+ public:
221
+ EnigmaImpl(int vocab_size, int block_size, int d_model, int n_layer, int n_head, float dropout, float norm_eps)
222
+ : toked_model(register_module("toked_model", torch::nn::Embedding(vocab_size, d_model))),
223
+ pos_encod(register_module("pos_encod", torch::nn::Embedding(block_size, d_model))),
224
+ norm_final(torch::nn::LayerNorm(torch::nn::LayerNormOptions({d_model}).eps(norm_eps))),
225
+ linear_final(register_module("linear_final", torch::nn::Linear(d_model, vocab_size))) {
226
+ for (int i = 0; i < n_layer; ++i) {
227
+ block_layers.push_back(register_module("block" + std::to_string(i), Block(d_model, n_head, norm_eps, dropout)));
228
+ }
229
+ register_buffer("block_size", torch::tensor(block_size));
230
+ _init_weights(this);
231
+ }
232
+
233
+ void _init_weights(torch::nn::Module* module) {
234
+ auto parameters = module->named_parameters();
235
+ for (auto& param : parameters) {
236
+ if (param.key().find("weight") != std::string::npos) {
237
+ torch::nn::init::normal_(param.value(), 0.0, 0.02);
238
+ } else if (param.key().find("bias") != std::string::npos) {
239
+ torch::nn::init::zeros_(param.value());
240
+ }
241
+ }
242
+ }
243
+
244
+ std::pair<torch::Tensor, torch::Tensor> forward(torch::Tensor idx, torch::Tensor targets=torch::Tensor()) {
245
+ torch::Tensor toked_model_out = toked_model->forward(idx);
246
+ torch::Tensor pos_encod_out = pos_encod->forward(torch::arange(idx.size(1)));
247
+ torch::Tensor x = toked_model_out + pos_encod_out;
248
+
249
+ for (auto& block : block_layers) {
250
+ x = block->forward(x);
251
+ }
252
+
253
+ torch::Tensor logits = linear_final->forward(norm_final->forward(x));
254
+
255
+ if (!targets.numel()) {
256
+ return {logits, torch::Tensor()};
257
+ } else {
258
+ logits = logits.view({-1, logits.size(-1)});
259
+ targets = targets.view({-1});
260
+ torch::Tensor loss = torch::nn::functional::cross_entropy(logits, targets);
261
+ return {logits, loss};
262
+ }
263
+ }
264
+
265
+ std::vector<std::vector<std::pair<torch::Tensor, float>>> complex_generate(torch::Tensor idx, int max_new_tokens, float temperature=1.0, int top_k=3, int beam_width=5) {
266
+ std::vector<std::vector<std::pair<torch::Tensor, float>>> completed_beams;
267
+ torch::Tensor current_idx = idx.clone();
268
+ std::vector<std::pair<torch::Tensor, float>> beam = {std::make_pair(current_idx, 0.0)};
269
+
270
+ for (int i = 0; i < max_new_tokens; ++i) {
271
+ std::vector<std::pair<torch::Tensor, float>> new_beam;
272
+
273
+ for (auto& beam_item : beam) {
274
+ torch::Tensor& current_idx = beam_item.first;
275
+ torch::Tensor logits, loss;
276
+ std::tie(logits, loss) = forward(current_idx);
277
+ logits = logits.index({torch::indexing::Slice(), -1}); // Get last token predictions
278
+
279
+ // Apply softmax and temperature
280
+ torch::Tensor probs = torch::nn::functional::softmax(logits / temperature, -1);
281
+
282
+ // Top-k sampling
283
+ if (top_k > 0) {
284
+ probs = top_k_filtering(probs, top_k);
285
+ }
286
+
287
+ // Sample from the distribution
288
+ torch::Tensor sampled_idx = torch::multinomial(probs, beam_width, true);
289
+
290
+ for (int j = 0; j < beam_width; ++j) {
291
+ torch::Tensor new_idx = torch::cat({current_idx, sampled_idx.index({torch::indexing::Slice(), j})}, 1);
292
+ torch::Tensor new_log_prob = beam_item.second + torch::log(probs.index({torch::indexing::Slice(), sampled_idx.index({torch::indexing::Slice(), j})}));
293
+ new_beam.push_back(std::make_pair(new_idx, new_log_prob.item<float>()));
294
+ }
295
+ }
296
+
297
+ // Sort new beam by log probabilities
298
+ std::sort(new_beam.begin(), new_beam.end(), [](const std::pair<torch::Tensor, float>& a, const std::pair<torch::Tensor, float>& b) {
299
+ return a.second > b.second;
300
+ });
301
+
302
+ // Only keep top beams
303
+ beam = std::vector<std::pair<torch::Tensor, float>>(new_beam.begin(), new_beam.begin() + beam_width);
304
+ }
305
+
306
+ completed_beams.push_back(beam);
307
+ return completed_beams;
308
+ }
309
+
310
+ torch::Tensor top_k_filtering(torch::Tensor logits, int top_k) {
311
+ torch::Tensor top_values, top_indices;
312
+ std::tie(top_values, top_indices) = torch::topk(logits, top_k, -1);
313
+
314
+ torch::Tensor min_value = torch::index_select(top_values, -1, torch::tensor({top_k-1}));
315
+ torch::Tensor filtered_logits = torch::where(logits < min_value, torch::full_like(logits, -std::numeric_limits<float>::infinity()), logits);
316
+ return filtered_logits;
317
+ }
318
+
319
+ private:
320
+ torch::nn::Embedding toked_model, pos_encod;
321
+ std::vector<Block> block_layers;
322
+ torch::nn::LayerNorm norm_final;
323
+ torch::nn::Linear linear_final;
324
+ int block_size;
325
+ };
326
+
327
+ TORCH_MODULE(Enigma);
328
+
329
+ int main() {
330
+ // Set seed
331
+ torch::manual_seed(1400);
332
+
333
+ // Create model
334
+ Enigma model(vocab_size, block_size, d_model, n_layer, n_head, dropout, norm_eps);
335
+ model->to(device);
336
+
337
+ // Define optimizer
338
+ torch::optim::AdamW optimizer(model->parameters(), torch::optim::AdamWOptions(learning_rate));
339
+
340
+ // Training loop
341
+ std::vector<float> train_losses, val_losses;
342
+ for (int iter = 0; iter < max_iters; ++iter) {
343
+ if (iter % eval_interval == 0 || iter == max_iters - 1) {
344
+ // Evaluate and print losses
345
+ auto losses = estimate_loss();
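+ // NOTE: estimate_loss() is assumed to be defined elsewhere, mirroring the Python helper that averages train/val loss over eval_iters batches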
346
+ std::cout << "step " << iter << ": train loss " << losses["train"] << ", val loss " << losses["val"] << std::endl;
347
+
348
+ // Save losses for plotting
349
+ train_losses.push_back(losses["train"]);
350
+ val_losses.push_back(losses["val"]);
351
+ }
352
+
353
+ // Get batch, forward pass, loss calculation, backward pass, optimizer step
354
+ auto [xb, yb] = get_batch("train");
355
+ torch::Tensor logits, loss;
356
+ std::tie(logits, loss) = model->forward(xb, yb);
357
+
358
+ optimizer.zero_grad();
359
+ loss.backward();
360
+ optimizer.step();
361
+ }
362
+
363
+ return 0;
364
+ }
enigma/generate.py ADDED
@@ -0,0 +1,126 @@
1
+ import os
2
+ current_directory = os.path.dirname(os.path.abspath(__file__))
3
+ os.chdir(current_directory)
4
+
5
+ with open('../parquet files/new_dna.txt', 'r', encoding='utf-8') as file:
6
+ captions = file.read()
7
+
8
+ print(f"{(len(captions)/1e6):.2f} million letters")
9
+
10
+ from tokenizer import PerCharTokenizer
11
+
12
+ tokenizer = PerCharTokenizer()
13
+ vocab_size = tokenizer.vocab_size
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch.nn import functional as F
18
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
19
+
20
+ from model import Transformer
21
+ model = Transformer(vocab_size=vocab_size)
22
+
23
+ class Generator(Transformer):
24
+ def __init__(self, vocab_size):
25
+ super().__init__(vocab_size)
26
+ self.vocab_size = vocab_size
27
+ # block_size is already set by Transformer.__init__
28
+
29
+ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):
30
+ """
31
+ generate new tokens using the trained model
32
+
33
+ Args:
34
+ - idx (Tensor): input tensor representing initial token indices
35
+ - max_new_tokens (int): max no of new tokens to generate
36
+ - temperature (float): softmax temperature for sampling
37
+ - top_k (int): no of top tokens to consider in sampling
38
+
39
+ Returns:
40
+ - generated_tokens (list): list of generated token indices
41
+ """
42
+ generated_tokens = []
43
+
44
+ for _ in range(max_new_tokens):
45
+ idx_cond = idx[:, -self.block_size:]
46
+ logits, _ = self(idx_cond)
47
+ logits = logits[:, -1, :]
48
+
49
+ scaled_logits = logits / temperature
50
+ if top_k > 0:
51
+ scaled_logits = self._top_k_filtering(scaled_logits, top_k)
52
+
53
+ probs = F.softmax(scaled_logits, dim=-1)
54
+ sampled_idx = torch.multinomial(probs, num_samples=1)
55
+ generated_tokens.append(sampled_idx.item())
56
+ idx = torch.cat((idx, sampled_idx), dim=1)
57
+
58
+ return generated_tokens
59
+
60
+ def generate_masked_tokens(self, idx, masked_indices, temperature=1.0, top_k=0):
61
+ """
62
+ Generate predictions for masked tokens using the trained model.
63
+
64
+ Args:
65
+ - idx (Tensor): input tensor representing token indices
66
+ - masked_indices (Tensor): tensor of indices indicating masked positions
67
+ - temperature (float): softmax temperature for sampling
68
+ - top_k (int): no of top tokens to consider in sampling
69
+
70
+ Returns:
71
+ - predicted_tokens (Tensor): tensor of predicted token indices
72
+ """
73
+ B, T = idx.shape
74
+
75
+ toked_model = self.toked_model(idx)
76
+ pos_encod = self.pos_encod(torch.arange(T, device=device))
77
+ x = toked_model + pos_encod
78
+
79
+ for layer in self.enc_layer:
80
+ x_out = layer(x)
81
+
82
+ for layer in self.dec_layer:
83
+ x_final = layer(x, x_out)
84
+
85
+ x_masked = x_final.clone()
86
+ x_masked[masked_indices] = self.toked_model(torch.tensor([6], device=device))
87
+
88
+ x_masked = self.norm_final(x_masked)
89
+ logits = self.linear_final(x_masked)
90
+
91
+ masked_logits = logits[masked_indices].view(-1, logits.size(-1))
92
+ scaled_logits = masked_logits / temperature
93
+ if top_k > 0:
94
+ scaled_logits = self._top_k_filtering(scaled_logits, top_k)
95
+
96
+ probs = F.softmax(scaled_logits, dim=-1)
97
+ predicted_indices = torch.argmax(probs, dim=-1)
98
+
99
+ return predicted_indices
100
+
101
+ def _top_k_filtering(self, logits, top_k):
102
+ """
103
+ filter logits to keep only the top-k tokens
104
+
105
+ Args:
106
+ - logits (Tensor): input tensor representing unscaled logits
107
+ - top_k (int): no of top tokens to keep
108
+
109
+ Returns:
110
+ - filtered_logits (Tensor): filtered logits with only top-k tokens remaining
111
+ """
112
+ values, indices = torch.topk(logits, top_k, dim=-1)
113
+ min_value = values[:, -1].unsqueeze(-1).expand_as(logits)
114
+ filtered_logits = torch.where(logits < min_value, torch.ones_like(logits) * -float('inf'), logits)
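+ # logits below the k-th largest value are set to -inf so softmax assigns them zero probability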
115
+
116
+ return filtered_logits
117
+
118
+ checkpoint_path = '../trained models/enigma_47m.pth'
119
+ checkpoint = torch.load(checkpoint_path)
120
+ model.load_state_dict(checkpoint)
121
+ m = model.to(device)
122
+
123
+ target_text = "AGTTCTGCGAT"
124
+ context = torch.tensor([tokenizer.encode(target_text)], dtype=torch.long, device=device)
125
+ generator = Generator(vocab_size).to(device)
+ generator.load_state_dict(checkpoint)
+ generated_output = tokenizer.decode(generator.generate(context, max_new_tokens=10, temperature=0.5, top_k=5))
126
+ print(f"{target_text}{generated_output}")
enigma/model.py ADDED
@@ -0,0 +1,388 @@
1
+ """
2
+ transformer-based model, with a few minimal tweaks
3
+ a ~2.5-billion-parameter model was trained with the current configuration
4
+ """
5
+
6
+ import torch
7
+ import json
8
+ import os
9
+ current_directory = os.path.dirname(os.path.abspath(__file__))
10
+ os.chdir(current_directory)
11
+
12
+ import torch.nn as nn
13
+ from torch.nn import functional as F
14
+
15
+ with open('config_enigma.json', 'r', encoding='utf-8') as file:
16
+ params = json.load(file)
17
+
18
+ batch_size = params['batch_size']
19
+ block_size = params['block_size']
20
+ n_head = params['n_head']
21
+ d_model = params['d_model']
22
+ n_layers = params['n_layer']
23
+ dropout = params['dropout']
24
+ norm_eps = params['norm_eps']
25
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
26
+
27
+ class AttentionHead(nn.Module):
28
+ """
29
+ initialize a single head of self attention.
30
+
31
+ Args:
32
+ - d_model (int): dimensionality of the model's hidden layers
33
+ - head_size (int): dimensionality of each attention head
34
+ - dropout (float): dropout probability
35
+ - block_size (int): the maximum sequence length for positional encoding
36
+ """
37
+ def __init__(self, d_model, head_size, dropout, block_size):
38
+ super().__init__()
39
+ self.key = nn.Linear(d_model, head_size, bias=True)
40
+ self.query = nn.Linear(d_model, head_size, bias=True)
41
+ self.value = nn.Linear(d_model, head_size, bias=False)
42
+ self.dropout = nn.Dropout(dropout)
43
+ self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
44
+
45
+ self.rel_pos_emb = nn.Parameter(torch.randn(block_size, block_size, head_size))
46
+
47
+ def forward(self, x, mask=False):
48
+ """
49
+ forward pass of a single attention head.
50
+
51
+ Args:
52
+ - x (Tensor): input tensor.
53
+ - mask (bool): flag indicating whether to apply masking
54
+
55
+ Returns:
56
+ - out (Tensor): output tensor after self attention
57
+ """
58
+ B, T, C = x.shape
59
+ key = self.key(x)
60
+ query = self.query(x)
61
+
62
+ scores = torch.matmul(query, key.transpose(-2, -1)) * (key.shape[-1] ** -0.5)
63
+ rel_pos_scores = torch.einsum('btc,tvc->btv', query, self.rel_pos_emb[:T, :T])
64
+ scores += rel_pos_scores
65
+
66
+ if mask:
67
+ scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
68
+
69
+ weights = F.softmax(scores, dim=-1)
70
+ weights = self.dropout(weights)
71
+
72
+ value = self.value(x)
73
+ out = torch.matmul(weights, value)
74
+ return out
75
+
76
+ class MultiHeadAttention(nn.Module):
77
+ """
78
+ initialize a multi-head attention module.
79
+
80
+ Args:
81
+ - d_model (int): dimensionality of the model's hidden layers
82
+ - n_head (int): no of attention heads
83
+ - dropout (float): dropout probability
84
+ - block_size (int): context length
85
+ """
86
+ def __init__(self, d_model, n_head, dropout, block_size):
87
+ head_size = d_model // n_head
88
+ super().__init__()
89
+ self.heads = nn.ModuleList([AttentionHead(d_model=d_model, dropout=dropout, head_size=head_size, block_size=block_size) for _ in range(n_head)])
90
+ self.proj = nn.Linear(n_head * head_size, d_model)
91
+ self.dropout = nn.Dropout(dropout)
92
+
93
+ def forward(self, x, mask):
94
+ """
95
+ forward pass of the multi-head attention module
96
+
97
+ Args:
98
+ - x (Tensor): input tensor
99
+ - mask (bool): flag indicating whether to apply masking
100
+
101
+ Returns:
102
+ - out (Tensor): output tensor after multi-head attention
103
+
104
+ """
105
+ out = torch.cat([h(x, mask=mask) for h in self.heads], dim=-1)
106
+ out = self.dropout(self.proj(out))
107
+ return out
108
+
109
+ class FeedForward(nn.Module):
110
+ """
111
+ initialize a feedforward network module
112
+
113
+ Args:
114
+ - d_model (int): the dimensionality of the model's hidden layers
115
+ - dropout (float): dropout probability
116
+
117
+ """
118
+ def __init__(self, d_model, dropout):
119
+ super().__init__()
120
+ self.net = nn.Sequential(
121
+ nn.Linear(d_model, 10*d_model),
122
+ nn.GELU(),
123
+ nn.Linear(10*d_model, d_model),
124
+ nn.Dropout(dropout)
125
+ )
126
+
127
+ def forward(self, x):
128
+ """
129
+ forward pass of the feedforward network module
130
+
131
+ Args:
132
+ - x (Tensor): input tensor
133
+
134
+ Returns:
135
+ - out (Tensor): output tensor after passing through the feedforward network
136
+ """
137
+ return self.net(x)
138
+
139
+ class EncoderNetwork(nn.Module):
140
+ """
141
+ initialize an encoder network module
142
+
143
+ Args:
144
+ - d_model (int): dimensionality of the model's hidden layers
145
+ - n_head (int): no of attention heads in multi-head attention layers
146
+ - norm_eps (float): epsilon value for layer normalization
147
+ - dropout (float): dropout probability
148
+ - block_size (int): the maximum sequence length for positional encoding
149
+ """
150
+ def __init__(self, d_model, n_head, norm_eps, dropout, block_size):
151
+ super().__init__()
152
+ self.s_att = MultiHeadAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)
153
+ self.ffwd = FeedForward(d_model, dropout)
154
+ self.dropout = nn.Dropout(dropout)
155
+ self.norm1 = nn.LayerNorm(d_model, eps=norm_eps)
156
+ self.norm2 = nn.LayerNorm(d_model, eps=norm_eps)
157
+
158
+ def forward(self, src):
159
+ """
160
+ forward pass of the encoder network module.
161
+
162
+ Args:
163
+ - src (Tensor): input tensor representing source data
164
+
165
+ Returns:
166
+ - src (Tensor): output tensor after passing through the encoder network
167
+ """
168
+ src2 = self.s_att(src, mask=False)
169
+ src = src + self.dropout(src2)
170
+ src = self.norm1(src)
171
+
172
+ src2 = self.ffwd(src)
173
+ src = src + self.dropout(src2)
174
+ src = self.norm2(src)
175
+
176
+ return src
177
+
178
+ class DecoderNetwork(nn.Module):
179
+ """
180
+ initialize a decoder network module
181
+
182
+ Args:
183
+ - d_model (int): dimensionality of the model's hidden layers
184
+ - n_head (int): no of attention heads in multi-head attention layers
185
+ - norm_eps (float): epsilon value for layer normalization
186
+ - dropout (float): dropout probability
187
+ - block_size (int): the maximum sequence length for positional encoding
188
+ """
189
+ def __init__(self, d_model, n_head, norm_eps, dropout, block_size):
190
+ super().__init__()
191
+ self.s_att = MultiHeadAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)
192
+ self.ffwd = FeedForward(d_model, dropout)
193
+ self.dropout = nn.Dropout(dropout)
194
+ self.norm1 = nn.LayerNorm(d_model, eps=norm_eps)
195
+ self.norm2 = nn.LayerNorm(d_model, eps=norm_eps)
196
+
197
+ def forward(self, src, att):
198
+ """
199
+ forward pass of the decoder network module.
200
+
201
+ Args:
202
+ - src (Tensor): input tensor, same as the encoder's inputs
203
+ - att (Tensor): encoder's output / attention matrix
204
+
205
+ Returns:
206
+ - src_f (Tensor): final output tensor
207
+ """
208
+ src2 = self.s_att(src, mask=True)
209
+ src = src + self.dropout(src2)
210
+ src = src + self.norm1(src)
211
+
212
+ att = src + att
213
+ att2 = self.s_att(att, mask=False)
214
+ att2 = att + self.dropout(att2)
215
+ trg = att2 + self.norm1(att2)
216
+
217
+ src_f2 = self.ffwd(self.norm2(trg))
218
+ src_f = trg + self.dropout(src_f2)
219
+ src_f = self.norm2(src_f)
220
+
221
+ return src_f
222
+
223
+ class Transformer(nn.Module):
224
+ """
225
+ initialize a Transformer model
226
+
227
+ Args:
228
+ - vocab_size (int): size of the vocabulary
229
+ - d_model (int): dimensionality of the model's hidden layers
230
+ - block_size (int): maximum sequence length for positional encoding/context length
231
+ - n_layers (int): number of encoder and decoder layers in the Transformer
232
+ - n_head (int): number of attention heads in multi-head attention layers
233
+ - norm_eps (float): epsilon value for layer normalization
234
+ - dropout (float): dropout probability
235
+ """
236
+ def __init__(self, vocab_size):
237
+ super().__init__()
238
+ self.block_size = block_size
239
+ self.toked_model = nn.Embedding(vocab_size, d_model)
240
+ self.pos_encod = nn.Embedding(block_size, d_model)
241
+ self.enc_layer = nn.ModuleList([EncoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])
242
+ self.dec_layer = nn.ModuleList([DecoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])
243
+
244
+ self.norm_final = nn.LayerNorm(d_model)
245
+ self.linear_final = nn.Linear(d_model, vocab_size)
246
+ self.dropout = nn.Dropout(dropout)
247
+ self.apply(self._init_weights)
248
+
249
+ def _init_weights(self, module):
250
+ """
251
+ initialize weights of linear and embedding layers
252
+
253
+ Args:
254
+ - module (nn.Module): the module to initialize weights for
255
+ """
256
+ if isinstance(module, nn.Linear):
257
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
258
+ if module.bias is not None:
259
+ torch.nn.init.zeros_(module.bias.data)
260
+ elif isinstance(module, nn.Embedding):
261
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
262
+
263
+ def forward(self, idx, targets=None):
264
+ """
265
+ forward pass of the transformer model
266
+
267
+ Args:
268
+ - idx (Tensor): input tensor representing token indices
269
+ - targets (Tensor): target tensor for computing loss during training
270
+
271
+ Returns:
272
+ - logits (Tensor): output logits from the final linear layer
273
+ - loss (Tensor): optional. computed cross-entropy loss if targets are provided, else None
274
+ """
275
+ B, T = idx.shape
276
+
277
+ toked_model = self.toked_model(idx)
278
+ pos_encod = self.pos_encod(torch.arange(T, device=device))
279
+ x = toked_model + pos_encod
280
+
281
+ for layer in self.enc_layer:
282
+ x_out = layer(x)
283
+
284
+ for layer in self.dec_layer:
285
+ x_final = layer(x, x_out)
286
+
287
+ x_final = self.norm_final(x_final)
288
+ logits = self.linear_final(x_final)
289
+
290
+ if targets is None:
291
+ loss = None
292
+
293
+ else:
294
+ B, T, C = logits.shape
295
+ logits = logits.view(B*T, C)
296
+ targets = targets.view(B*T)
297
+ loss = F.cross_entropy(logits, targets)
298
+
299
+ return logits, loss
300
+
301
+ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):
302
+ """
303
+ generate new tokens using the trained model
304
+
305
+ Args:
306
+ - idx (Tensor): input tensor representing initial token indices
307
+ - max_new_tokens (int): max no of new tokens to generate
308
+ - temperature (float): softmax temperature for sampling
309
+ - top_k (int): no of top tokens to consider in sampling
310
+
311
+ Returns:
312
+ - generated_tokens (list): list of generated token indices
313
+ """
314
+ generated_tokens = []
315
+
316
+ for _ in range(max_new_tokens):
317
+ idx_cond = idx[:, -self.block_size:]
318
+ logits, _ = self(idx_cond)
319
+ logits = logits[:, -1, :]
320
+
321
+ scaled_logits = logits / temperature
322
+ if top_k > 0:
323
+ scaled_logits = self._top_k_filtering(scaled_logits, top_k)
324
+
325
+ probs = F.softmax(scaled_logits, dim=-1)
326
+ sampled_idx = torch.multinomial(probs, num_samples=1)
327
+ generated_tokens.append(sampled_idx.item())
328
+ idx = torch.cat((idx, sampled_idx), dim=1)
329
+
330
+ return generated_tokens
331
+
332
+ def generate_masked_tokens(self, idx, masked_indices, temperature=1.0, top_k=0):
333
+ """
334
+ Generate predictions for masked tokens using the trained model.
335
+
336
+ Args:
337
+ - idx (Tensor): input tensor representing token indices
338
+ - masked_indices (Tensor): tensor of indices indicating masked positions
339
+ - temperature (float): softmax temperature for sampling
340
+ - top_k (int): no of top tokens to consider in sampling
341
+
342
+ Returns:
343
+ - predicted_tokens (Tensor): tensor of predicted token indices
344
+ """
345
+ B, T = idx.shape
346
+
347
+ toked_model = self.toked_model(idx)
348
+ pos_encod = self.pos_encod(torch.arange(T, device=device))
349
+ x = toked_model + pos_encod
350
+
351
+ for layer in self.enc_layer:
352
+ x_out = layer(x)
353
+
354
+ for layer in self.dec_layer:
355
+ x_final = layer(x, x_out)
356
+
357
+ x_masked = x_final.clone()
358
+ x_masked[masked_indices] = self.toked_model(torch.tensor([6], device=device))
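+ # token id 6 is presumably the tokenizer's mask-token id; its embedding overwrites the masked positions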
359
+
360
+ x_masked = self.norm_final(x_masked)
361
+ logits = self.linear_final(x_masked)
362
+
363
+ masked_logits = logits[masked_indices].view(-1, logits.size(-1))
364
+ scaled_logits = masked_logits / temperature
365
+ if top_k > 0:
366
+ scaled_logits = self._top_k_filtering(scaled_logits, top_k)
367
+
368
+ probs = F.softmax(scaled_logits, dim=-1)
369
+ predicted_indices = torch.argmax(probs, dim=-1)
370
+
371
+ return predicted_indices
372
+
373
+ def _top_k_filtering(self, logits, top_k):
374
+ """
375
+ filter logits to keep only the top-k tokens
376
+
377
+ Args:
378
+ - logits (Tensor): input tensor representing unscaled logits
379
+ - top_k (int): no of top tokens to keep
380
+
381
+ Returns:
382
+ - filtered_logits (Tensor): filtered logits with only top-k tokens remaining
383
+ """
384
+ values, indices = torch.topk(logits, top_k, dim=-1)
385
+ min_value = values[:, -1].unsqueeze(-1).expand_as(logits)
386
+ filtered_logits = torch.where(logits < min_value, torch.ones_like(logits) * -float('inf'), logits)
387
+
388
+ return filtered_logits
enigma/run.py ADDED
@@ -0,0 +1,100 @@
1
+ """
2
+ use this file to train the model
3
+
4
+ working:
5
+ - imports various dependencies first, and then loads the training data
6
+ - tokenizes it, per-character basis
7
+ - loads the required hyper-parameters and the model file
8
+ - trains it till 'max_iters' and saves the model state, and generates outputs
9
+
10
+ with the current configuration, the model can reach up to ~60 million parameters
11
+ and can reach ~99% accuracy on next-token prediction
12
+ """
13
+
14
+ import torch
15
+ import json
16
+ import os
17
+ current_directory = os.path.dirname(os.path.abspath(__file__))
18
+ os.chdir(current_directory)
19
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
20
+
21
+ with open('../parquet files/new_dna.txt', 'r', encoding='utf-8') as file:
22
+ captions = file.read()
23
+
24
+ print(f"{(len(captions)/1e6):.2f} million letters")
25
+
26
+ from tokenizer import PerCharTokenizer
27
+
28
+ tokenizer = PerCharTokenizer()
29
+ vocab_size = tokenizer.vocab_size
30
+ # Train and test splits
31
+ data = torch.tensor(tokenizer.encode(captions), dtype=torch.long)
32
+ n = int(0.9*len(data)) # first 90% will be train, rest val
33
+ train_data = data[:n]
34
+ val_data = data[n:]
35
+
36
+ with open('config_enigma.json', 'r', encoding='utf-8') as file:
37
+ params = json.load(file)
38
+
39
+ # required parameters
40
+ batch_size = params['batch_size']
41
+ block_size = params['block_size']
42
+ max_iters = params['max_iters']
43
+ eval_interval = params['eval_interval']
44
+ eval_iters = params['eval_iters']
45
+ learning_rate = params['learning_rate']
46
+
47
+ torch.manual_seed(1400)
48
+ # data loading
49
+ def get_batch(split):
50
+ # generate a small batch of data of inputs x and targets y
51
+ data = train_data if split == 'train' else val_data
52
+ ix = torch.randint(len(data) - block_size, (batch_size,))
53
+ x = torch.stack([data[i:i+block_size] for i in ix])
54
+ y = torch.stack([data[i+1:i+block_size+1] for i in ix])
55
+ x, y = x.to(device), y.to(device)
56
+ return x, y
57
+
58
+ @torch.no_grad()
59
+ def estimate_loss():
60
+ out = {}
61
+ model.eval()
62
+ for split in ['train', 'val']:
63
+ losses = torch.zeros(eval_iters)
64
+ for k in range(eval_iters):
65
+ X, Y = get_batch(split)
66
+ logits, loss = model(X, Y)
67
+ losses[k] = loss.item()
68
+ out[split] = losses.mean()
69
+ model.train()
70
+ return out
71
+
72
+ from model import Transformer
73
+ model = Transformer(vocab_size=vocab_size)
74
+ m = model.to(device)
75
+
76
+ # no of parameters
77
+ n_param = sum(p.numel() for p in m.parameters())/1e6
78
+ print(f"{n_param:.2f} million")
79
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
80
+ steps = []
81
+ train_losses = []
82
+ val_losses = []
83
+
84
+ for iter in range(max_iters):
85
+
86
+ if iter % eval_interval == 0 or iter == max_iters - 1:
87
+ losses = estimate_loss()
88
+ print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
89
+
90
+ steps.append(iter)
91
+ train_losses.append(losses['train'])
92
+ val_losses.append(losses['val'])
93
+
94
+ xb, yb = get_batch('train')
95
+ logits, loss = model(xb, yb)
96
+ optimizer.zero_grad(set_to_none=True)
97
+ loss.backward()
98
+ optimizer.step()
99
+
100
+ torch.save(model.state_dict(), f'enigma_{n_param:.0f}m.pth')
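+ # the saved state dict can later be reloaded with model.load_state_dict(torch.load(...)), as done in generate.py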