Maykeye committed
Commit • c9fc3d0 • 1 Parent(s): 9d689e0

Initial commit (without weights)

Browse files:
- README.md +31 -0
- mambabit.py +127 -0
- trainer.ipynb +196 -0
README.md CHANGED
@@ -1,3 +1,34 @@
---
license: apache-2.0
---

MambaBit: a bit-level cursed model with vocab size = 2.

* 4 layers, vocab size = 2, embedding size = 4096, i.e. one 4096-dim float32 vector per bit.

* Training was done on the first 8,030,848 bits of Tiny Shakespeare, taking 10 hours on a laptop with 16 GB VRAM, using batches of 9 sequences of 128*8 bits each. Training code is included in trainer.ipynb.

* To run the model, run `python mambabit.py "As sun raised over"`.
Expected output:
```
As sun raised over me.

LEONTES:
Now means means me not so much as my father,
In the good many lord, and my father come.

KING RICHARD III:
What is my father come and my father,
In the good lord, and my father come and before his father.

GLOUCESTER:
Now the goes of men, a
```

* Bytes are encoded with the most significant bit fed first, e.g. '7' = [0, 0, 1, 1, 0, 1, 1, 1], so the MSB 0 is fed first rather than last (as it would be in the LSB-first order [1, 1, 1, 0, 1, 1, 0, 0]). The intuition is that the leading bits of a byte change less frequently than the trailing ones, so the model can decide "I think I will produce a digit", then "I think I will produce 7", instead of "so I spat something out. Should it be a number? A letter? Dunno." (See the encoding sketch after this list.)

* I originally tried BF16, but the model either went to NaN (with the default, large LR) or the gradients were so small the weights didn't change (with a smaller LR). I switched back to F32; some layers still initialize their weights scaled by a factor of 0.001, which I hoped would stop the model from going to NaN.
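
As referenced in the encoding bullet above, here is a minimal plain-Python sketch (not from the repo) of the MSB-first byte-to-bit encoding; the vectorized equivalent lives in `string_to_bits` in mambabit.py:

```python
def char_to_bits_msb_first(c: str) -> list[int]:
    # One byte -> 8 bits, most significant bit first.
    byte = c.encode()[0]
    return [(byte >> i) & 1 for i in range(7, -1, -1)]

assert char_to_bits_msb_first('7') == [0, 0, 1, 1, 0, 1, 1, 1]  # ord('7') == 55 == 0b00110111
```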
mambabit.py ADDED
@@ -0,0 +1,127 @@
```python
import torch
import torch.nn as nn
from torch import Tensor
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.utils.generation import InferenceParams
from tqdm.auto import tqdm
import sys

dim_model = 4096
n_vocab = 2   # tokens are single bits
n_layers = 4


@torch.no_grad()
def string_to_bits(text: str, _cache=[]) -> Tensor:
    # The mutable default _cache deliberately memoizes the 256x8 lookup table.
    all_values = torch.arange(0, 256)
    if not _cache:
        bits = [((all_values & (1 << i)) != 0).int() for i in range(7, -1, -1)]
        bits_tensor = torch.stack(bits).mT   # (256, 8), MSB first
        _cache.append(bits_tensor)
    else:
        bits_tensor = _cache[0]
    binary = text.encode()
    raw = torch.frombuffer(binary, dtype=torch.uint8).int()
    return bits_tensor[raw].long().ravel()   # 8 bits per input byte


@torch.no_grad()
def bits_to_string(bits: Tensor):
    if bits.dim() == 2:
        return [bits_to_string(t) for t in bits]
    assert bits.dim() == 1
    assert len(bits) % 8 == 0
    factors = torch.tensor([2**i for i in range(7, -1, -1)]).to(device=bits.device)
    as_bytes = bits.view(-1, 8)
    as_bytes = (as_bytes * factors).sum(-1)
    return ''.join([chr(x) for x in as_bytes])


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(n_vocab, dim_model)
        self.emb.weight.data *= 0.001   # small init, see README

    def forward(self, x):
        return self.emb(x)


class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(dim_model)
        self.decoder = nn.Linear(dim_model, n_vocab, False)
        self.decoder.weight.data *= 0.001   # small init, see README

    def forward(self, x):
        x = self.norm(x)
        x = self.decoder(x)
        return x


class MambaBit(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc = Encoder()
        self.layers = nn.ModuleList([Mamba(dim_model) for _ in range(n_layers)])
        self.dec = Decoder()

    def forward(self, x):
        x = self.enc(x)
        for layer in self.layers:
            x = layer(x)
        x = self.dec(x)
        return x


class MambaBitWithInference(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc = Encoder()
        self.layers = nn.ModuleList([Mamba(dim_model, layer_idx=i) for i in range(n_layers)])
        self.dec = Decoder()

    def forward(self, x, inference_parms=None):
        x = self.enc(x)
        for i, layer in enumerate(self.layers):
            x = layer(x, inference_params=inference_parms)
        x = self.dec(x)
        return x


# test using O(N^2) cacheless stateless algorithm.
@torch.no_grad()
def test_n2(m: MambaBit, prompt: str, chars=10):
    x = string_to_bits(prompt).cuda()[None]
    process = chars * 8
    for i in tqdm(range(process)):
        y = m(x)
        new = y[:, -1:].argmax(-1)
        x = torch.cat((x, new), 1)
    return bits_to_string(x)


# test using O(N) by reusing state
@torch.no_grad()
def test_n(m: MambaBit, prompt: str, chars=10):
    x = string_to_bits(prompt).cuda()[None]
    process = chars * 8

    inference_parms = InferenceParams(
        max_seqlen=x.numel() + process,
        max_batch_size=1)

    y = m(x, inference_parms=inference_parms)
    new = y[:, -1:].argmax(-1)
    for i in tqdm(range(process)):
        x = torch.cat((x, new), 1)
        inference_parms.seqlen_offset = x.numel() + i
        y = m(new, inference_parms=inference_parms)
        new = y[:, -1:].argmax(-1)
    return bits_to_string(x)


def run():
    mamba_bit = MambaBitWithInference().cuda()
    mamba_bit.load_state_dict(torch.load("mamba_bit.bin"))

    prompt = "FIRST CITIZE" if len(sys.argv) != 2 else sys.argv[1]
    # test_n2 is O(N^2), test_n is O(N) but inference_params are not well documented
    s = test_n(mamba_bit, prompt, chars=256)[0]
    print(s)


if __name__ == "__main__":
    run()
```
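
Not part of the commit, but a quick way to sanity-check the two codecs above, a hedged sketch that needs no weights or GPU (it does assume mambabit.py and mamba_ssm are importable):

```python
from mambabit import string_to_bits, bits_to_string

s = "As sun raised over"
bits = string_to_bits(s)          # 1-D LongTensor, 8 bits per byte, MSB first
assert bits.numel() == 8 * len(s)
assert bits_to_string(bits) == s  # lossless round trip for ASCII text
```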
trainer.ipynb ADDED
@@ -0,0 +1,196 @@
Cell 1:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F   # added: train() below uses F; missing in the committed cell
from torch import Tensor
from pathlib import Path          # added: Cell 2 uses Path; missing in the committed cell
import random
from tqdm.auto import tqdm
from mamba_ssm.modules.mamba_simple import Mamba

def model_numel(m: nn.Module):
    return sum(p.numel() for p in m.parameters())
```
Cell 2:
```python
raw_txt = Path("../shake.txt").read_text()
total_len = len(raw_txt)
aux_len = int(total_len * 0.05)

head_txt, test_txt = raw_txt[:-aux_len], raw_txt[-aux_len:]
train_txt, valid_txt = head_txt[:-aux_len], head_txt[-aux_len:]
```
Cell 3:
```python
len(train_txt)
```
Cell 4:
```python
from mambabit import string_to_bits, bits_to_string

train_ds = string_to_bits(train_txt)
valid_ds = string_to_bits(valid_txt)
test_ds = string_to_bits(test_txt)
```
Cell 5:
```python
def random_batches(split: Tensor, n_batch: int, bs: int):
    assert bs % 8 == 0, "have mercy"
    max_allowed_pos = len(split) // 8 - bs // 8

    values = []
    for i in range(n_batch):
        pos = random.randint(0, max_allowed_pos)
        values.append(split[pos*8: pos*8+bs])
    return torch.stack(values).cuda()
```
Cell 6:
```python
from mambabit import dim_model, n_vocab, n_layers, MambaBit
```
Cell 7:
```python
mamba_bit = MambaBit().cuda()
```
Cell 8:
```python
if True:
    mamba_bit.load_state_dict(torch.load("mamba_bit.bin"))
```
Cell 9:
```python
def train(m: nn.Module,
          n_epoch: int = 100,
          n_batch: int = 4,
          bs: int = 256):
    opt = torch.optim.AdamW(m.parameters(), lr=0.0001, fused=True)

    for e in (bar := tqdm(range(n_epoch))):
        b = random_batches(train_ds, n_batch, bs)

        y_pred = m(b)
        y_pred = y_pred[:, :-1].reshape(-1, n_vocab)
        y_true = b[:, 1:].ravel()

        loss = F.cross_entropy(y_pred, y_true)
        loss.backward()
        opt.step()
        opt.zero_grad()

        l = loss.item()
        bar.set_description(f"L:{l:.10f}")
```
Cell 10:
```python
if True:
    train(mamba_bit, 5000, 9, 8*128)
```
Cell 11:
```python
torch.save(mamba_bit.state_dict(), "mamba_bit.bin")
```
Cell 12:
```python
# TEST
@torch.no_grad()
def test(prompt: str, chars=10):
    x0 = string_to_bits(prompt).cuda()[None]  # fixed: committed cell called decode_bits, which doesn't exist
    x = x0.clone()
    process = chars * 8
    for _ in tqdm(range(process)):
        y = mamba_bit(x)
        new = y[:, -1:].argmax(-1)
        x = torch.cat((x, new), 1)
    return bits_to_string(x)                  # fixed: committed cell called encode_bits, which doesn't exist

print(test("FIRST CIT", chars=10))
```
Notebook metadata: kernelspec "sd" (Python), Python 3.11.8, nbformat 4.