Maykeye committed
Commit c9fc3d0
1 Parent(s): 9d689e0

Initial commit (without weights)

Files changed (3)
  1. README.md +31 -0
  2. mambabit.py +127 -0
  3. trainer.ipynb +196 -0
README.md CHANGED
@@ -1,3 +1,34 @@
  ---
  license: apache-2.0
  ---
+
+ MambaBit: a cursed bit-level model with vocab size = 2.
+
+ * 4 layers, vocab size = 2, embedding size = 4096; one float32 vector per bit.
+
+ * Training was done on the first 8030848 bits of Tiny Shakespeare, in 10 hours on a laptop with 16GB VRAM, using batches of 9 sequences of 128*8 = 1024 bits (128 characters) each. Training code is included in trainer.ipynb.
+
+ * To run the model, run `python mambabit.py "As sun raised over"`.
+ Expected output:
+ ```
+ As sun raised over me.
+
+ LEONTES:
+ Now means means me not so much as my father,
+ In the good many lord, and my father come.
+
+ KING RICHARD III:
+ What is my father come and my father,
+ In the good lord, and my father come and before his father.
+
+ GLOUCESTER:
+ Now the goes of men, a
+ ```
+
+ * Bytes are encoded with the most significant bit fed first. E.g. '7' = [0, 0, 1, 1, 0, 1, 1, 1], so the MSB 0 is fed first
+ rather than last, as it would be with the LSB-first order [1, 1, 1, 0, 1, 1, 0, 0]. The intuition is that the leading bits of a byte change less often than the trailing ones, so the model can decide "I think I will produce a digit", then "I think I will produce 7", instead of "so I spat something out; should it be a number? a letter? dunno". A minimal sketch of this encoding is shown below.
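+
+ In plain Python (illustrative only; the real lookup-table version is `string_to_bits` in mambabit.py):
+
+ ```
+ def byte_to_bits_msb_first(byte: int) -> list:
+     # Walk the bit positions from 7 (MSB) down to 0 (LSB).
+     return [(byte >> i) & 1 for i in range(7, -1, -1)]
+
+ assert byte_to_bits_msb_first(ord('7')) == [0, 0, 1, 1, 0, 1, 1, 1]
+ ```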
+
+ * I originally tried BF16, but the model went to NaN (with the default, large LR), or the gradients were so small that the weights didn't change (with a smaller LR). I switched back to F32; however, some layers still initialize their weights scaled down by a factor of 0.001, as I hoped that would stop the model from going to NaN (a sketch of this pattern follows).
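+
+ A minimal sketch of that scaled-down initialization, the same pattern `Encoder` and `Decoder` in mambabit.py use:
+
+ ```
+ import torch.nn as nn
+
+ emb = nn.Embedding(2, 4096)   # n_vocab=2, dim_model=4096
+ emb.weight.data *= 0.001      # shrink the default init to keep early logits tame
+ ```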
mambabit.py ADDED
@@ -0,0 +1,127 @@
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ from mamba_ssm.modules.mamba_simple import Mamba
+ from mamba_ssm.utils.generation import InferenceParams
+ from tqdm.auto import tqdm
+ import sys
+
+ dim_model = 4096
+ n_vocab = 2
+ n_layers = 4
+
+
+ @torch.no_grad()
+ def string_to_bits(text: str, _cache=[]) -> Tensor:
+     # Build (and memoize in the mutable default argument) a 256x8 lookup
+     # table mapping every byte value to its bits, most significant bit first.
+     if not _cache:
+         all_values = torch.arange(0, 256)
+         bits = [((all_values & (1 << i)) != 0).int() for i in range(7, -1, -1)]
+         bits_tensor = torch.stack(bits).mT
+         _cache.append(bits_tensor)
+     else:
+         bits_tensor = _cache[0]
+     binary = text.encode()
+     # bytearray makes the buffer writable, which keeps torch.frombuffer quiet.
+     raw = torch.frombuffer(bytearray(binary), dtype=torch.uint8).int()
+     return bits_tensor[raw].long().ravel()
+
+
+ @torch.no_grad()
+ def bits_to_string(bits: Tensor):
+     if bits.dim() == 2:
+         return [bits_to_string(t) for t in bits]
+     assert bits.dim() == 1
+     assert len(bits) % 8 == 0
+     # Weights 128, 64, ..., 1 reassemble each group of 8 bits into a byte.
+     factors = torch.tensor([2**i for i in range(7, -1, -1)]).to(device=bits.device)
+     as_bytes = bits.view(-1, 8)
+     as_bytes = (as_bytes * factors).sum(-1)
+     return ''.join([chr(x) for x in as_bytes])
+
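+ # Quick round-trip sanity check for the two helpers above (illustrative,
+ # not part of the original training/inference path):
+ #   string_to_bits("7")                  -> tensor([0, 0, 1, 1, 0, 1, 1, 1])
+ #   bits_to_string(string_to_bits("hi")) -> "hi"
+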
+ class Encoder(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.emb = nn.Embedding(n_vocab, dim_model)
+         # Scale the init down x0.001 to keep early activations small (see README).
+         self.emb.weight.data *= 0.001
+
+     def forward(self, x):
+         return self.emb(x)
+
+
+ class Decoder(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.norm = nn.LayerNorm(dim_model)
+         self.decoder = nn.Linear(dim_model, n_vocab, bias=False)
+         self.decoder.weight.data *= 0.001
+
+     def forward(self, x):
+         x = self.norm(x)
+         x = self.decoder(x)
+         return x
+
+
+ class MambaBit(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.enc = Encoder()
+         self.layers = nn.ModuleList([Mamba(dim_model) for _ in range(n_layers)])
+         self.dec = Decoder()
+
+     def forward(self, x):
+         x = self.enc(x)
+         for layer in self.layers:
+             x = layer(x)
+         x = self.dec(x)
+         return x
+
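+ # Scale note (illustrative): the embedding and the decoder each hold only
+ # n_vocab * dim_model = 2 * 4096 = 8192 weights; virtually all of the
+ # model's parameters live inside the n_layers=4 Mamba blocks.
+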
+ class MambaBitWithInference(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.enc = Encoder()
+         # layer_idx lets each block find its own cached state in InferenceParams.
+         self.layers = nn.ModuleList([Mamba(dim_model, layer_idx=i) for i in range(n_layers)])
+         self.dec = Decoder()
+
+     def forward(self, x, inference_params=None):
+         x = self.enc(x)
+         for layer in self.layers:
+             x = layer(x, inference_params=inference_params)
+         x = self.dec(x)
+         return x
+
+
+ # Test using the O(N^2) cacheless, stateless algorithm: re-run the whole
+ # sequence through the model for every generated bit.
+ @torch.no_grad()
+ def test_n2(m: MambaBit, prompt: str, chars=10):
+     x = string_to_bits(prompt).cuda()[None]
+     process = chars * 8
+     for _ in tqdm(range(process)):
+         y = m(x)
+         new = y[:, -1:].argmax(-1)
+         x = torch.cat((x, new), 1)
+     return bits_to_string(x)
+
+ # Test using O(N) generation: reuse the recurrent state between steps so
+ # each new bit costs a single forward step instead of a full re-scan.
+ @torch.no_grad()
+ def test_n(m: MambaBitWithInference, prompt: str, chars=10):
+     x = string_to_bits(prompt).cuda()[None]
+     process = chars * 8
+
+     inference_params = InferenceParams(
+         max_seqlen=x.numel() + process,
+         max_batch_size=1)
+
+     # Prefill: run the whole prompt once to populate the per-layer caches.
+     y = m(x, inference_params=inference_params)
+     new = y[:, -1:].argmax(-1)
+     for i in tqdm(range(process)):
+         x = torch.cat((x, new), 1)
+         # Any positive offset switches the Mamba blocks to single-step mode;
+         # keep it equal to the number of bits already processed.
+         inference_params.seqlen_offset = x.shape[1] - 1
+         y = m(new, inference_params=inference_params)
+         new = y[:, -1:].argmax(-1)
+     return bits_to_string(x)
+
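+ # Note on InferenceParams (my reading of mamba_ssm's code, since it is not
+ # well documented): it carries a per-layer state cache keyed by layer_idx,
+ # and once seqlen_offset > 0 each Mamba block takes its single-token step
+ # path, updating the cached conv/SSM state instead of re-scanning the
+ # whole sequence.
+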
+ def run():
+     mamba_bit = MambaBitWithInference().cuda()
+     mamba_bit.load_state_dict(torch.load("mamba_bit.bin"))
+
+     prompt = "FIRST CITIZE" if len(sys.argv) != 2 else sys.argv[1]
+     # test_n2 is O(N^2); test_n is O(N) but leans on the sparsely documented
+     # inference_params machinery.
+     s = test_n(mamba_bit, prompt, chars=256)[0]
+     print(s)
+
+
+ if __name__ == "__main__":
+     run()
trainer.ipynb ADDED
@@ -0,0 +1,196 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import torch\n",
+     "import torch.nn as nn\n",
+     "import torch.nn.functional as F\n",
+     "from torch import Tensor\n",
+     "from pathlib import Path\n",
+     "import random\n",
+     "from tqdm.auto import tqdm\n",
+     "from mamba_ssm.modules.mamba_simple import Mamba\n",
+     "\n",
+     "def model_numel(m: nn.Module):\n",
+     "    return sum(p.numel() for p in m.parameters())"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "raw_txt = Path(\"../shake.txt\").read_text()\n",
+     "total_len = len(raw_txt)\n",
+     "aux_len = int(total_len * 0.05)\n",
+     "\n",
+     "head_txt, test_txt = raw_txt[:-aux_len], raw_txt[-aux_len:]\n",
+     "train_txt, valid_txt = head_txt[:-aux_len], head_txt[-aux_len:]"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "len(train_txt)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from mambabit import string_to_bits, bits_to_string\n",
+     "\n",
+     "train_ds = string_to_bits(train_txt)\n",
+     "valid_ds = string_to_bits(valid_txt)\n",
+     "test_ds = string_to_bits(test_txt)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "def random_batches(split: Tensor, n_batch: int, bs: int):\n",
+     "    assert bs % 8 == 0, \"have mercy\"\n",
+     "    # Sample byte-aligned windows of bs bits each.\n",
+     "    max_allowed_pos = len(split) // 8 - bs // 8\n",
+     "\n",
+     "    values = []\n",
+     "    for i in range(n_batch):\n",
+     "        pos = random.randint(0, max_allowed_pos)\n",
+     "        values.append(split[pos*8 : pos*8 + bs])\n",
+     "    return torch.stack(values).cuda()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from mambabit import dim_model, n_vocab, n_layers, MambaBit"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "mamba_bit = MambaBit().cuda()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Set to False for a from-scratch run (the repo ships without weights).\n",
+     "if True:\n",
+     "    mamba_bit.load_state_dict(torch.load(\"mamba_bit.bin\"))"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "def train(m: nn.Module,\n",
+     "          n_epoch: int = 100,\n",
+     "          n_batch: int = 4,\n",
+     "          bs: int = 256):\n",
+     "    opt = torch.optim.AdamW(m.parameters(), lr=0.0001, fused=True)\n",
+     "\n",
+     "    for e in (bar := tqdm(range(n_epoch))):\n",
+     "        b = random_batches(train_ds, n_batch, bs)\n",
+     "\n",
+     "        # Next-bit prediction: logits for all but the last position,\n",
+     "        # targets shifted one bit to the left.\n",
+     "        y_pred = m(b)\n",
+     "        y_pred = y_pred[:, :-1].reshape(-1, n_vocab)\n",
+     "        y_true = b[:, 1:].ravel()\n",
+     "\n",
+     "        loss = F.cross_entropy(y_pred, y_true)\n",
+     "        loss.backward()\n",
+     "        opt.step()\n",
+     "        opt.zero_grad()\n",
+     "\n",
+     "        l = loss.item()\n",
+     "        bar.set_description(f\"L:{l:.10f}\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "if True:\n",
+     "    train(mamba_bit, 5000, 9, 8*128)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "torch.save(mamba_bit.state_dict(), \"mamba_bit.bin\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# TEST\n",
+     "@torch.no_grad()\n",
+     "def test(prompt: str, chars=10):\n",
+     "    x = string_to_bits(prompt).cuda()[None]\n",
+     "    process = chars * 8\n",
+     "    for _ in tqdm(range(process)):\n",
+     "        y = mamba_bit(x)\n",
+     "        new = y[:, -1:].argmax(-1)\n",
+     "        x = torch.cat((x, new), 1)\n",
+     "    return bits_to_string(x)\n",
+     "\n",
+     "print(test(\"FIRST CIT\", chars=10))"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "sd",
+    "language": "python",
+    "name": "sd"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.11.8"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }