Upload folder using huggingface_hub
Browse files- README.md +160 -0
- bit_trainer.py +199 -0
- byte_trainer.py +176 -0
- dibit_trainer.py +200 -0
- purebit_trainer.py +275 -0
README.md
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Binary Transformers: Learning Language from Raw Binary
|
| 2 |
+
|
| 3 |
+
**Zero-tokenization transformers that learn directly from network bytes, bits, and beyond.**
|
| 4 |
+
|
| 5 |
+
This repository contains four novel transformer architectures exploring the limits of minimal vocabulary learning:
|
| 6 |
+
|
| 7 |
+
| Model | Vocab | Input | Weights | Description |
|
| 8 |
+
|-------|-------|-------|---------|-------------|
|
| 9 |
+
| **Byte-level** | 256 | bytes (0x00-0xFF) | real | One token per byte value |
|
| 10 |
+
| **Bit-level** | 2 | bits (0, 1) | real | Pure binary, 8 tokens per byte |
|
| 11 |
+
| **Dibit** | 4 | dibits (00,01,10,11) | real | 2-bit tokens, 4 per byte |
|
| 12 |
+
| **Pure Binary** | 2 | bits (0, 1) | **binary (-1/+1)** | BITS ALL THE WAY DOWN |
|
| 13 |
+
|
| 14 |
+
## Why?
|
| 15 |
+
|
| 16 |
+
Traditional LLMs use tokenizers (BPE, SentencePiece) with 32k-256k vocabulary. This creates:
|
| 17 |
+
- Tokenizer overhead and complexity
|
| 18 |
+
- Language/domain bias baked into vocabulary
|
| 19 |
+
- Preprocessing bottleneck
|
| 20 |
+
|
| 21 |
+
**What if we eliminated tokenization entirely?**
|
| 22 |
+
|
| 23 |
+
These models learn directly from raw binary data - no tokenizer, no preprocessing, just bytes flowing into neural networks. The ultimate goal: **wire-speed learning** where models absorb network traffic in real-time.
|
| 24 |
+
|
| 25 |
+
## Results
|
| 26 |
+
|
| 27 |
+
### Byte-Level (vocab=256)
|
| 28 |
+
```
|
| 29 |
+
Data: 350KB web crawl
|
| 30 |
+
BPB: 4.68 (vs 8.0 random = 41% compression)
|
| 31 |
+
Speed: 8.7 KB/s learning rate
|
| 32 |
+
```
|
| 33 |
+
Learns HTML structure, XML tags, timestamps from raw bytes.
|
| 34 |
+
|
| 35 |
+
### Bit-Level (vocab=2)
|
| 36 |
+
```
|
| 37 |
+
Data: 550KB
|
| 38 |
+
Entropy: 1.008 bit/bit (vs 1.0 random)
|
| 39 |
+
Speed: 0.7 KB/s
|
| 40 |
+
```
|
| 41 |
+
Pure binary learning - discovers byte boundaries and ASCII from 0s and 1s.
|
| 42 |
+
|
| 43 |
+
### Dibit (vocab=4: 00,01,10,11)
|
| 44 |
+
```
|
| 45 |
+
Data: 37KB
|
| 46 |
+
BPB: 7.70 (vs 8.0 random = 3.7% compression)
|
| 47 |
+
Speed: 0.26 KB/s
|
| 48 |
+
```
|
| 49 |
+
2-bit tokens provide 2x context efficiency vs bit-level.
|
| 50 |
+
|
| 51 |
+
### Pure Binary (vocab=2, binary weights)
|
| 52 |
+
```
|
| 53 |
+
Data: 37KB
|
| 54 |
+
Entropy: 1.027 bit/bit
|
| 55 |
+
Binary params: 99.8%
|
| 56 |
+
```
|
| 57 |
+
**BITS ALL THE WAY DOWN** - input bits, binary weights, output bits. On specialized hardware, this enables XNOR+popcount operations instead of multiply-accumulate.
|
| 58 |
+
|
| 59 |
+
## Architecture
|
| 60 |
+
|
| 61 |
+
All models use standard transformer architecture with:
|
| 62 |
+
- Causal self-attention
|
| 63 |
+
- GELU activation
|
| 64 |
+
- LayerNorm
|
| 65 |
+
- AdamW optimizer
|
| 66 |
+
- Straight-Through Estimator (STE) for binary weight gradients
|
| 67 |
+
|
| 68 |
+
### Key Innovation: Online Learning
|
| 69 |
+
|
| 70 |
+
Unlike traditional batch training, these models learn from streaming data:
|
| 71 |
+
- Micro-batches (32-512 tokens)
|
| 72 |
+
- Single-pass, no data curation
|
| 73 |
+
- Real-time network stream compatible
|
| 74 |
+
|
| 75 |
+
## Usage
|
| 76 |
+
|
| 77 |
+
### Byte-Level
|
| 78 |
+
```bash
|
| 79 |
+
# Pipe any data source
|
| 80 |
+
cat data.bin | python byte_trainer.py
|
| 81 |
+
curl -s http://example.com | python byte_trainer.py
|
| 82 |
+
zcat crawl.jsonl.gz | python byte_trainer.py
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
### Bit-Level
|
| 86 |
+
```bash
|
| 87 |
+
cat data.bin | python bit_trainer.py
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
### Dibit (2-bit tokens)
|
| 91 |
+
```bash
|
| 92 |
+
cat data.bin | python dibit_trainer.py
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
### Pure Binary (binary weights)
|
| 96 |
+
```bash
|
| 97 |
+
cat data.bin | python purebit_trainer.py
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
## Configuration
|
| 101 |
+
|
| 102 |
+
Edit the CONFIG dict in each trainer:
|
| 103 |
+
|
| 104 |
+
```python
|
| 105 |
+
CONFIG = {
|
| 106 |
+
"d": 256, # embedding dimension
|
| 107 |
+
"layers": 6, # transformer layers
|
| 108 |
+
"heads": 8, # attention heads
|
| 109 |
+
"vocab": 2, # vocabulary size
|
| 110 |
+
"ctx": 2048, # context length
|
| 111 |
+
}
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
## Files
|
| 115 |
+
|
| 116 |
+
```
|
| 117 |
+
byte_trainer.py # Vocab=256, one token per byte
|
| 118 |
+
bit_trainer.py # Vocab=2, pure bits
|
| 119 |
+
dibit_trainer.py # Vocab=4, 2-bit tokens (00,01,10,11)
|
| 120 |
+
purebit_trainer.py # Vocab=2 + binary weights (-1/+1)
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
## Insights
|
| 124 |
+
|
| 125 |
+
1. **Byte-level is sweet spot** - 256 vocab captures ASCII structure efficiently while eliminating tokenizer overhead
|
| 126 |
+
|
| 127 |
+
2. **Bit-level works but slow** - 8x longer sequences mean 8x less context per forward pass
|
| 128 |
+
|
| 129 |
+
3. **Dibit balances** - 2-bit tokens give 2x context vs bit-level while staying "pure binary"
|
| 130 |
+
|
| 131 |
+
4. **Binary weights viable** - 99.8% binary params learn almost as well as real weights, enabling massive hardware speedups
|
| 132 |
+
|
| 133 |
+
5. **HTML is natural SFT** - Web data contains instruction-following patterns: `<h3>Question</h3><p>Answer`, `<dt>Term</dt><dd>Definition</dd>`, JSON Q&A
|
| 134 |
+
|
| 135 |
+
## Future Work
|
| 136 |
+
|
| 137 |
+
- Scale to billions of parameters
|
| 138 |
+
- Custom CUDA kernels for binary ops (XNOR + popcount)
|
| 139 |
+
- FPGA/ASIC implementation for true wire-speed learning
|
| 140 |
+
- Hierarchical binary models (bit → byte → word emergence)
|
| 141 |
+
|
| 142 |
+
## Citation
|
| 143 |
+
|
| 144 |
+
```bibtex
|
| 145 |
+
@misc{opentransformer2026binary,
|
| 146 |
+
title={Binary Transformers: Learning Language from Raw Binary},
|
| 147 |
+
author={OpenTransformer},
|
| 148 |
+
year={2026},
|
| 149 |
+
publisher={HuggingFace},
|
| 150 |
+
url={https://huggingface.co/OpenTransformer/binary-transformers}
|
| 151 |
+
}
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
## License
|
| 155 |
+
|
| 156 |
+
MIT
|
| 157 |
+
|
| 158 |
+
## Acknowledgments
|
| 159 |
+
|
| 160 |
+
Built with PyTorch. Trained on vast.ai GPU instances. Part of the AGILLM research project.
|
bit_trainer.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
BIT-LEVEL TRANSFORMER - The Ultimate Zero-Overhead Model
|
| 4 |
+
Vocab = 2 (just 0 and 1)
|
| 5 |
+
No tokenization. No bytes. Pure binary.
|
| 6 |
+
|
| 7 |
+
Each byte becomes 8 tokens (bits).
|
| 8 |
+
Model learns ALL structure from raw bits.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import sys
|
| 12 |
+
import math
|
| 13 |
+
import time
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from collections import deque
|
| 18 |
+
|
| 19 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 20 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 21 |
+
|
| 22 |
+
# BIT-LEVEL CONFIG - ABSOLUTE UNIT
|
| 23 |
+
CONFIG = {
|
| 24 |
+
"d": 768, # GPT-2 small size
|
| 25 |
+
"layers": 12, # DEEP for bit pattern learning
|
| 26 |
+
"heads": 12,
|
| 27 |
+
"vocab": 2, # JUST 0 AND 1!
|
| 28 |
+
"ctx": 4096, # 512 bytes of context
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
LR = 3e-4 # learning rate
|
| 32 |
+
UPDATE_EVERY = 2048 # bits between updates (256 bytes worth) - BIGGER BATCHES
|
| 33 |
+
PRINT_EVERY = 100000 # bits
|
| 34 |
+
|
| 35 |
+
class BitAttention(nn.Module):
|
| 36 |
+
def __init__(self, d, h):
|
| 37 |
+
super().__init__()
|
| 38 |
+
self.h, self.dk = h, d // h
|
| 39 |
+
self.qkv = nn.Linear(d, 3 * d, bias=False)
|
| 40 |
+
self.proj = nn.Linear(d, d, bias=False)
|
| 41 |
+
|
| 42 |
+
def forward(self, x, mask=None):
|
| 43 |
+
B, N, D = x.shape
|
| 44 |
+
qkv = self.qkv(x).view(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
|
| 45 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 46 |
+
att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
|
| 47 |
+
if mask is not None:
|
| 48 |
+
att = att + mask
|
| 49 |
+
return self.proj((F.softmax(att, -1) @ v).transpose(1, 2).reshape(B, N, D))
|
| 50 |
+
|
| 51 |
+
class BitBlock(nn.Module):
|
| 52 |
+
def __init__(self, d, h):
|
| 53 |
+
super().__init__()
|
| 54 |
+
self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
|
| 55 |
+
self.attn = BitAttention(d, h)
|
| 56 |
+
self.ff = nn.Sequential(nn.Linear(d, 4*d), nn.GELU(), nn.Linear(4*d, d))
|
| 57 |
+
|
| 58 |
+
def forward(self, x, mask):
|
| 59 |
+
x = x + self.attn(self.ln1(x), mask)
|
| 60 |
+
return x + self.ff(self.ln2(x))
|
| 61 |
+
|
| 62 |
+
class BitTransformer(nn.Module):
|
| 63 |
+
"""Transformer with vocab=2 (just 0 and 1)"""
|
| 64 |
+
def __init__(self, cfg):
|
| 65 |
+
super().__init__()
|
| 66 |
+
d, L, h = cfg["d"], cfg["layers"], cfg["heads"]
|
| 67 |
+
self.emb = nn.Embedding(2, d) # ONLY 2 EMBEDDINGS!
|
| 68 |
+
self.blocks = nn.ModuleList([BitBlock(d, h) for _ in range(L)])
|
| 69 |
+
self.ln = nn.LayerNorm(d)
|
| 70 |
+
self.head = nn.Linear(d, 2, bias=False) # predict 0 or 1
|
| 71 |
+
|
| 72 |
+
def forward(self, x):
|
| 73 |
+
B, N = x.shape
|
| 74 |
+
mask = torch.triu(torch.ones(N, N, device=x.device), 1) * -1e9
|
| 75 |
+
h = self.emb(x)
|
| 76 |
+
for block in self.blocks:
|
| 77 |
+
h = block(h, mask)
|
| 78 |
+
return self.head(self.ln(h))
|
| 79 |
+
|
| 80 |
+
def count_params(self):
|
| 81 |
+
return sum(p.numel() for p in self.parameters())
|
| 82 |
+
|
| 83 |
+
def byte_to_bits(byte_val):
|
| 84 |
+
"""Convert byte to 8 bits (MSB first)"""
|
| 85 |
+
return [(byte_val >> (7 - i)) & 1 for i in range(8)]
|
| 86 |
+
|
| 87 |
+
def bits_to_byte(bits):
|
| 88 |
+
"""Convert 8 bits back to byte"""
|
| 89 |
+
val = 0
|
| 90 |
+
for i, b in enumerate(bits[:8]):
|
| 91 |
+
val |= (b << (7 - i))
|
| 92 |
+
return val
|
| 93 |
+
|
| 94 |
+
class BitTrainer:
|
| 95 |
+
def __init__(self, model, lr=LR):
|
| 96 |
+
self.model = model.to(DEVICE)
|
| 97 |
+
self.opt = torch.optim.AdamW(model.parameters(), lr=lr)
|
| 98 |
+
self.ctx_size = CONFIG["ctx"]
|
| 99 |
+
self.buffer = deque(maxlen=self.ctx_size + 1)
|
| 100 |
+
|
| 101 |
+
self.bits_seen = 0
|
| 102 |
+
self.bytes_seen = 0
|
| 103 |
+
self.total_loss = 0.0
|
| 104 |
+
self.updates = 0
|
| 105 |
+
self.start_time = time.time()
|
| 106 |
+
|
| 107 |
+
def ingest_byte(self, byte_val):
|
| 108 |
+
"""Convert byte to 8 bits and absorb"""
|
| 109 |
+
bits = byte_to_bits(byte_val)
|
| 110 |
+
for bit in bits:
|
| 111 |
+
self.buffer.append(bit)
|
| 112 |
+
self.bits_seen += 1
|
| 113 |
+
|
| 114 |
+
if len(self.buffer) >= UPDATE_EVERY + 1 and self.bits_seen % UPDATE_EVERY == 0:
|
| 115 |
+
self._update()
|
| 116 |
+
|
| 117 |
+
self.bytes_seen += 1
|
| 118 |
+
|
| 119 |
+
if self.bits_seen % PRINT_EVERY == 0:
|
| 120 |
+
self._print_stats()
|
| 121 |
+
|
| 122 |
+
if self.bytes_seen % 500000 == 0 and self.bytes_seen > 0:
|
| 123 |
+
self._save()
|
| 124 |
+
|
| 125 |
+
def _update(self):
|
| 126 |
+
bits = list(self.buffer)
|
| 127 |
+
x = torch.tensor(bits[:-1], device=DEVICE, dtype=torch.long).unsqueeze(0)
|
| 128 |
+
y = torch.tensor(bits[1:], device=DEVICE, dtype=torch.long).unsqueeze(0)
|
| 129 |
+
|
| 130 |
+
self.model.train()
|
| 131 |
+
logits = self.model(x)
|
| 132 |
+
loss = F.cross_entropy(
|
| 133 |
+
logits[:, -UPDATE_EVERY:].reshape(-1, 2),
|
| 134 |
+
y[:, -UPDATE_EVERY:].reshape(-1)
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
self.opt.zero_grad()
|
| 138 |
+
loss.backward()
|
| 139 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
|
| 140 |
+
self.opt.step()
|
| 141 |
+
|
| 142 |
+
self.total_loss += loss.item()
|
| 143 |
+
self.updates += 1
|
| 144 |
+
|
| 145 |
+
def _print_stats(self):
|
| 146 |
+
elapsed = time.time() - self.start_time
|
| 147 |
+
bits_per_sec = self.bits_seen / elapsed if elapsed > 0 else 0
|
| 148 |
+
bytes_per_sec = self.bytes_seen / elapsed if elapsed > 0 else 0
|
| 149 |
+
avg_loss = self.total_loss / max(1, self.updates)
|
| 150 |
+
|
| 151 |
+
# For bits: random is 1.0 (coin flip), lower = learning
|
| 152 |
+
# Entropy in bits per bit
|
| 153 |
+
entropy = avg_loss / math.log(2)
|
| 154 |
+
compression = (1.0 - entropy) * 100 # % compression vs random
|
| 155 |
+
|
| 156 |
+
print(f"[{elapsed:.0f}s] {self.bytes_seen/1000:.1f}KB | {bytes_per_sec/1000:.1f} KB/s | "
|
| 157 |
+
f"loss={avg_loss:.4f} | entropy={entropy:.3f} bit/bit | "
|
| 158 |
+
f"compression={compression:.1f}%", flush=True)
|
| 159 |
+
|
| 160 |
+
def _save(self):
|
| 161 |
+
avg_loss = self.total_loss / max(1, self.updates)
|
| 162 |
+
kb = self.bytes_seen // 1000
|
| 163 |
+
ckpt = {
|
| 164 |
+
"model": self.model.state_dict(),
|
| 165 |
+
"bits": self.bits_seen,
|
| 166 |
+
"bytes": self.bytes_seen,
|
| 167 |
+
"loss": avg_loss,
|
| 168 |
+
}
|
| 169 |
+
torch.save(ckpt, f"/workspace/bit_ckpt_{kb}kb.pt")
|
| 170 |
+
print(f"[SAVED] bit_ckpt_{kb}kb.pt", flush=True)
|
| 171 |
+
|
| 172 |
+
def main():
|
| 173 |
+
print(f"BIT-LEVEL TRANSFORMER - Vocab = 2 (just 0 and 1)", flush=True)
|
| 174 |
+
print(f"Config: {CONFIG}", flush=True)
|
| 175 |
+
print(f"Device: {DEVICE}", flush=True)
|
| 176 |
+
|
| 177 |
+
model = BitTransformer(CONFIG)
|
| 178 |
+
params = model.count_params()
|
| 179 |
+
print(f"Parameters: {params:,} ({params/1e6:.2f}M)", flush=True)
|
| 180 |
+
print(f"Vocab: 2 (literally just 0 and 1)", flush=True)
|
| 181 |
+
print(f"Each byte = 8 bit tokens", flush=True)
|
| 182 |
+
|
| 183 |
+
trainer = BitTrainer(model)
|
| 184 |
+
|
| 185 |
+
print(f"Listening for bytes (FAST batch mode)...", flush=True)
|
| 186 |
+
|
| 187 |
+
# Read in large chunks for speed
|
| 188 |
+
CHUNK_SIZE = 8192 # 8KB chunks = 65536 bits
|
| 189 |
+
while True:
|
| 190 |
+
chunk = sys.stdin.buffer.read(CHUNK_SIZE)
|
| 191 |
+
if not chunk:
|
| 192 |
+
break
|
| 193 |
+
for byte in chunk:
|
| 194 |
+
trainer.ingest_byte(byte)
|
| 195 |
+
|
| 196 |
+
print(f"Stream ended. Total: {trainer.bytes_seen:,} bytes = {trainer.bits_seen:,} bits", flush=True)
|
| 197 |
+
|
| 198 |
+
if __name__ == "__main__":
|
| 199 |
+
main()
|
byte_trainer.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
BINARY TRANSFORMER - Raw network bytes → neural network
|
| 4 |
+
No tokenizer. No preprocessing. Just bytes.
|
| 5 |
+
|
| 6 |
+
Vocab = 256 (one token per byte value 0x00-0xFF)
|
| 7 |
+
Input: Raw bytes from network stream via stdin
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import sys
|
| 11 |
+
import math
|
| 12 |
+
import time
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
from collections import deque
|
| 17 |
+
|
| 18 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 19 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 20 |
+
|
| 21 |
+
# Binary model config - TINY for speed
|
| 22 |
+
CONFIG = {
|
| 23 |
+
"d": 128, # smaller embedding
|
| 24 |
+
"layers": 3, # fewer layers
|
| 25 |
+
"heads": 4,
|
| 26 |
+
"vocab": 256, # ONE TOKEN PER BYTE
|
| 27 |
+
"ctx": 1024, # longer context (bytes are fine-grained)
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
LR = 3e-4
|
| 31 |
+
UPDATE_EVERY = 64 # bytes between updates
|
| 32 |
+
PRINT_EVERY = 50000 # bytes between stats
|
| 33 |
+
|
| 34 |
+
class ByteAttention(nn.Module):
|
| 35 |
+
def __init__(self, d, h):
|
| 36 |
+
super().__init__()
|
| 37 |
+
self.h, self.dk = h, d // h
|
| 38 |
+
self.qkv = nn.Linear(d, 3 * d, bias=False)
|
| 39 |
+
self.proj = nn.Linear(d, d, bias=False)
|
| 40 |
+
|
| 41 |
+
def forward(self, x, mask=None):
|
| 42 |
+
B, N, D = x.shape
|
| 43 |
+
qkv = self.qkv(x).view(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
|
| 44 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 45 |
+
att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
|
| 46 |
+
if mask is not None:
|
| 47 |
+
att = att + mask
|
| 48 |
+
return self.proj((F.softmax(att, -1) @ v).transpose(1, 2).reshape(B, N, D))
|
| 49 |
+
|
| 50 |
+
class ByteBlock(nn.Module):
|
| 51 |
+
def __init__(self, d, h):
|
| 52 |
+
super().__init__()
|
| 53 |
+
self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
|
| 54 |
+
self.attn = ByteAttention(d, h)
|
| 55 |
+
self.ff = nn.Sequential(nn.Linear(d, 4*d), nn.GELU(), nn.Linear(4*d, d))
|
| 56 |
+
|
| 57 |
+
def forward(self, x, mask):
|
| 58 |
+
x = x + self.attn(self.ln1(x), mask)
|
| 59 |
+
return x + self.ff(self.ln2(x))
|
| 60 |
+
|
| 61 |
+
class BinaryTransformer(nn.Module):
|
| 62 |
+
def __init__(self, cfg):
|
| 63 |
+
super().__init__()
|
| 64 |
+
d, L, h, V = cfg["d"], cfg["layers"], cfg["heads"], cfg["vocab"]
|
| 65 |
+
self.emb = nn.Embedding(V, d) # 256 embeddings, one per byte
|
| 66 |
+
self.blocks = nn.ModuleList([ByteBlock(d, h) for _ in range(L)])
|
| 67 |
+
self.ln = nn.LayerNorm(d)
|
| 68 |
+
self.head = nn.Linear(d, V, bias=False)
|
| 69 |
+
self.head.weight = self.emb.weight # tie weights
|
| 70 |
+
|
| 71 |
+
def forward(self, x):
|
| 72 |
+
B, N = x.shape
|
| 73 |
+
mask = torch.triu(torch.ones(N, N, device=x.device), 1) * -1e9
|
| 74 |
+
h = self.emb(x)
|
| 75 |
+
for block in self.blocks:
|
| 76 |
+
h = block(h, mask)
|
| 77 |
+
return self.head(self.ln(h))
|
| 78 |
+
|
| 79 |
+
def count_params(self):
|
| 80 |
+
return sum(p.numel() for p in self.parameters())
|
| 81 |
+
|
| 82 |
+
class BinaryTrainer:
|
| 83 |
+
def __init__(self, model, lr=LR):
|
| 84 |
+
self.model = model.to(DEVICE)
|
| 85 |
+
self.opt = torch.optim.AdamW(model.parameters(), lr=lr)
|
| 86 |
+
self.ctx_size = CONFIG["ctx"]
|
| 87 |
+
self.buffer = deque(maxlen=self.ctx_size + 1)
|
| 88 |
+
|
| 89 |
+
self.bytes_seen = 0
|
| 90 |
+
self.total_loss = 0.0
|
| 91 |
+
self.updates = 0
|
| 92 |
+
self.start_time = time.time()
|
| 93 |
+
|
| 94 |
+
def ingest_byte(self, byte_val):
|
| 95 |
+
"""Absorb a single byte (0-255)"""
|
| 96 |
+
self.buffer.append(byte_val)
|
| 97 |
+
self.bytes_seen += 1
|
| 98 |
+
|
| 99 |
+
if len(self.buffer) >= UPDATE_EVERY + 1 and self.bytes_seen % UPDATE_EVERY == 0:
|
| 100 |
+
self._update()
|
| 101 |
+
|
| 102 |
+
if self.bytes_seen % PRINT_EVERY == 0:
|
| 103 |
+
self._print_stats()
|
| 104 |
+
|
| 105 |
+
# Save checkpoint every 500k bytes
|
| 106 |
+
if self.bytes_seen % 500000 == 0 and self.bytes_seen > 0:
|
| 107 |
+
self._save()
|
| 108 |
+
|
| 109 |
+
def _update(self):
|
| 110 |
+
tokens = list(self.buffer)
|
| 111 |
+
x = torch.tensor(tokens[:-1], device=DEVICE, dtype=torch.long).unsqueeze(0)
|
| 112 |
+
y = torch.tensor(tokens[1:], device=DEVICE, dtype=torch.long).unsqueeze(0)
|
| 113 |
+
|
| 114 |
+
self.model.train()
|
| 115 |
+
logits = self.model(x)
|
| 116 |
+
loss = F.cross_entropy(
|
| 117 |
+
logits[:, -UPDATE_EVERY:].reshape(-1, 256),
|
| 118 |
+
y[:, -UPDATE_EVERY:].reshape(-1)
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
self.opt.zero_grad()
|
| 122 |
+
loss.backward()
|
| 123 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
|
| 124 |
+
self.opt.step()
|
| 125 |
+
|
| 126 |
+
self.total_loss += loss.item()
|
| 127 |
+
self.updates += 1
|
| 128 |
+
|
| 129 |
+
def _print_stats(self):
|
| 130 |
+
elapsed = time.time() - self.start_time
|
| 131 |
+
rate = self.bytes_seen / elapsed if elapsed > 0 else 0
|
| 132 |
+
avg_loss = self.total_loss / max(1, self.updates)
|
| 133 |
+
mb = self.bytes_seen / 1_000_000
|
| 134 |
+
|
| 135 |
+
# Bits per byte (compression metric) - log2(256)=8 is random, lower is learning
|
| 136 |
+
bpb = avg_loss / math.log(2)
|
| 137 |
+
|
| 138 |
+
print(f"[{elapsed:.0f}s] {mb:.2f}MB | {rate/1000:.1f} KB/s | "
|
| 139 |
+
f"loss={avg_loss:.3f} | bpb={bpb:.2f} | updates={self.updates}", flush=True)
|
| 140 |
+
|
| 141 |
+
def _save(self):
|
| 142 |
+
avg_loss = self.total_loss / max(1, self.updates)
|
| 143 |
+
mb = self.bytes_seen // 1_000_000
|
| 144 |
+
ckpt = {
|
| 145 |
+
"model": self.model.state_dict(),
|
| 146 |
+
"bytes": self.bytes_seen,
|
| 147 |
+
"loss": avg_loss,
|
| 148 |
+
}
|
| 149 |
+
torch.save(ckpt, f"byte_ckpt_{mb}mb.pt")
|
| 150 |
+
print(f"[SAVED] {mb}MB checkpoint", flush=True)
|
| 151 |
+
|
| 152 |
+
def main():
|
| 153 |
+
print(f"BINARY TRANSFORMER - Raw bytes learning", flush=True)
|
| 154 |
+
print(f"Config: {CONFIG}", flush=True)
|
| 155 |
+
print(f"Device: {DEVICE}", flush=True)
|
| 156 |
+
|
| 157 |
+
model = BinaryTransformer(CONFIG)
|
| 158 |
+
params = model.count_params()
|
| 159 |
+
print(f"Parameters: {params:,} ({params/1e6:.1f}M)", flush=True)
|
| 160 |
+
print(f"Vocab: 256 (one per byte)", flush=True)
|
| 161 |
+
|
| 162 |
+
trainer = BinaryTrainer(model)
|
| 163 |
+
|
| 164 |
+
print(f"Listening for raw bytes on stdin...", flush=True)
|
| 165 |
+
|
| 166 |
+
# Read raw bytes from stdin
|
| 167 |
+
while True:
|
| 168 |
+
byte = sys.stdin.buffer.read(1)
|
| 169 |
+
if not byte:
|
| 170 |
+
break
|
| 171 |
+
trainer.ingest_byte(byte[0])
|
| 172 |
+
|
| 173 |
+
print(f"Stream ended. Total bytes: {trainer.bytes_seen:,}", flush=True)
|
| 174 |
+
|
| 175 |
+
if __name__ == "__main__":
|
| 176 |
+
main()
|
dibit_trainer.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
DIBIT TRANSFORMER - 2-bit tokens
|
| 4 |
+
Vocab = 4 (00, 01, 10, 11)
|
| 5 |
+
Each byte = 4 tokens (vs 8 for bit-level)
|
| 6 |
+
Better context efficiency while still pure binary!
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import sys
|
| 10 |
+
import math
|
| 11 |
+
import time
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from collections import deque
|
| 16 |
+
|
| 17 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 18 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 19 |
+
|
| 20 |
+
# DIBIT CONFIG - 2-bit tokens
|
| 21 |
+
CONFIG = {
|
| 22 |
+
"d": 512, # good size
|
| 23 |
+
"layers": 12,
|
| 24 |
+
"heads": 8,
|
| 25 |
+
"vocab": 4, # 00, 01, 10, 11
|
| 26 |
+
"ctx": 4096, # 1024 bytes of context (2x more than bit-level!)
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
LR = 3e-4
|
| 30 |
+
UPDATE_EVERY = 512 # dibits between updates (128 bytes worth)
|
| 31 |
+
PRINT_EVERY = 50000 # dibits
|
| 32 |
+
|
| 33 |
+
class DibitAttention(nn.Module):
|
| 34 |
+
def __init__(self, d, h):
|
| 35 |
+
super().__init__()
|
| 36 |
+
self.h, self.dk = h, d // h
|
| 37 |
+
self.qkv = nn.Linear(d, 3 * d, bias=False)
|
| 38 |
+
self.proj = nn.Linear(d, d, bias=False)
|
| 39 |
+
|
| 40 |
+
def forward(self, x, mask=None):
|
| 41 |
+
B, N, D = x.shape
|
| 42 |
+
qkv = self.qkv(x).view(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
|
| 43 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 44 |
+
att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
|
| 45 |
+
if mask is not None:
|
| 46 |
+
att = att + mask
|
| 47 |
+
return self.proj((F.softmax(att, -1) @ v).transpose(1, 2).reshape(B, N, D))
|
| 48 |
+
|
| 49 |
+
class DibitBlock(nn.Module):
|
| 50 |
+
def __init__(self, d, h):
|
| 51 |
+
super().__init__()
|
| 52 |
+
self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
|
| 53 |
+
self.attn = DibitAttention(d, h)
|
| 54 |
+
self.ff = nn.Sequential(nn.Linear(d, 4*d), nn.GELU(), nn.Linear(4*d, d))
|
| 55 |
+
|
| 56 |
+
def forward(self, x, mask):
|
| 57 |
+
x = x + self.attn(self.ln1(x), mask)
|
| 58 |
+
return x + self.ff(self.ln2(x))
|
| 59 |
+
|
| 60 |
+
class DibitTransformer(nn.Module):
|
| 61 |
+
"""Transformer with vocab=4 (00, 01, 10, 11)"""
|
| 62 |
+
def __init__(self, cfg):
|
| 63 |
+
super().__init__()
|
| 64 |
+
d, L, h = cfg["d"], cfg["layers"], cfg["heads"]
|
| 65 |
+
self.emb = nn.Embedding(4, d) # 4 embeddings for dibits
|
| 66 |
+
self.blocks = nn.ModuleList([DibitBlock(d, h) for _ in range(L)])
|
| 67 |
+
self.ln = nn.LayerNorm(d)
|
| 68 |
+
self.head = nn.Linear(d, 4, bias=False) # predict 00, 01, 10, or 11
|
| 69 |
+
|
| 70 |
+
def forward(self, x):
|
| 71 |
+
B, N = x.shape
|
| 72 |
+
mask = torch.triu(torch.ones(N, N, device=x.device), 1) * -1e9
|
| 73 |
+
h = self.emb(x)
|
| 74 |
+
for block in self.blocks:
|
| 75 |
+
h = block(h, mask)
|
| 76 |
+
return self.head(self.ln(h))
|
| 77 |
+
|
| 78 |
+
def count_params(self):
|
| 79 |
+
return sum(p.numel() for p in self.parameters())
|
| 80 |
+
|
| 81 |
+
def byte_to_dibits(byte_val):
|
| 82 |
+
"""Convert byte to 4 dibits (2-bit chunks, MSB first)
|
| 83 |
+
e.g., 0b11100100 -> [3, 2, 1, 0] (11, 10, 01, 00)
|
| 84 |
+
"""
|
| 85 |
+
return [
|
| 86 |
+
(byte_val >> 6) & 0b11, # bits 7-6
|
| 87 |
+
(byte_val >> 4) & 0b11, # bits 5-4
|
| 88 |
+
(byte_val >> 2) & 0b11, # bits 3-2
|
| 89 |
+
byte_val & 0b11, # bits 1-0
|
| 90 |
+
]
|
| 91 |
+
|
| 92 |
+
def dibits_to_byte(dibits):
|
| 93 |
+
"""Convert 4 dibits back to byte"""
|
| 94 |
+
return (dibits[0] << 6) | (dibits[1] << 4) | (dibits[2] << 2) | dibits[3]
|
| 95 |
+
|
| 96 |
+
class DibitTrainer:
|
| 97 |
+
def __init__(self, model, lr=LR):
|
| 98 |
+
self.model = model.to(DEVICE)
|
| 99 |
+
self.opt = torch.optim.AdamW(model.parameters(), lr=lr)
|
| 100 |
+
self.ctx_size = CONFIG["ctx"]
|
| 101 |
+
self.buffer = deque(maxlen=self.ctx_size + 1)
|
| 102 |
+
|
| 103 |
+
self.dibits_seen = 0
|
| 104 |
+
self.bytes_seen = 0
|
| 105 |
+
self.total_loss = 0.0
|
| 106 |
+
self.updates = 0
|
| 107 |
+
self.start_time = time.time()
|
| 108 |
+
|
| 109 |
+
def ingest_byte(self, byte_val):
|
| 110 |
+
"""Convert byte to 4 dibits and absorb"""
|
| 111 |
+
dibits = byte_to_dibits(byte_val)
|
| 112 |
+
for dibit in dibits:
|
| 113 |
+
self.buffer.append(dibit)
|
| 114 |
+
self.dibits_seen += 1
|
| 115 |
+
|
| 116 |
+
if len(self.buffer) >= UPDATE_EVERY + 1 and self.dibits_seen % UPDATE_EVERY == 0:
|
| 117 |
+
self._update()
|
| 118 |
+
|
| 119 |
+
self.bytes_seen += 1
|
| 120 |
+
|
| 121 |
+
if self.dibits_seen % PRINT_EVERY == 0:
|
| 122 |
+
self._print_stats()
|
| 123 |
+
|
| 124 |
+
if self.bytes_seen % 500000 == 0 and self.bytes_seen > 0:
|
| 125 |
+
self._save()
|
| 126 |
+
|
| 127 |
+
def _update(self):
|
| 128 |
+
tokens = list(self.buffer)
|
| 129 |
+
x = torch.tensor(tokens[:-1], device=DEVICE, dtype=torch.long).unsqueeze(0)
|
| 130 |
+
y = torch.tensor(tokens[1:], device=DEVICE, dtype=torch.long).unsqueeze(0)
|
| 131 |
+
|
| 132 |
+
self.model.train()
|
| 133 |
+
logits = self.model(x)
|
| 134 |
+
loss = F.cross_entropy(
|
| 135 |
+
logits[:, -UPDATE_EVERY:].reshape(-1, 4),
|
| 136 |
+
y[:, -UPDATE_EVERY:].reshape(-1)
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
self.opt.zero_grad()
|
| 140 |
+
loss.backward()
|
| 141 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
|
| 142 |
+
self.opt.step()
|
| 143 |
+
|
| 144 |
+
self.total_loss += loss.item()
|
| 145 |
+
self.updates += 1
|
| 146 |
+
|
| 147 |
+
def _print_stats(self):
|
| 148 |
+
elapsed = time.time() - self.start_time
|
| 149 |
+
bytes_per_sec = self.bytes_seen / elapsed if elapsed > 0 else 0
|
| 150 |
+
avg_loss = self.total_loss / max(1, self.updates)
|
| 151 |
+
|
| 152 |
+
# For dibits: random is log(4)/log(2) = 2.0 bits per dibit
|
| 153 |
+
# Entropy in bits per dibit
|
| 154 |
+
entropy_per_dibit = avg_loss / math.log(2)
|
| 155 |
+
# Convert to bits per byte (4 dibits per byte)
|
| 156 |
+
bpb = entropy_per_dibit * 4
|
| 157 |
+
# Random byte = 8 bits, so compression vs random
|
| 158 |
+
compression = (1.0 - bpb/8) * 100
|
| 159 |
+
|
| 160 |
+
print(f"[{elapsed:.0f}s] {self.bytes_seen/1000:.1f}KB | {bytes_per_sec/1000:.2f} KB/s | "
|
| 161 |
+
f"loss={avg_loss:.4f} | bpb={bpb:.2f} | compression={compression:.1f}%", flush=True)
|
| 162 |
+
|
| 163 |
+
def _save(self):
|
| 164 |
+
avg_loss = self.total_loss / max(1, self.updates)
|
| 165 |
+
kb = self.bytes_seen // 1000
|
| 166 |
+
ckpt = {
|
| 167 |
+
"model": self.model.state_dict(),
|
| 168 |
+
"dibits": self.dibits_seen,
|
| 169 |
+
"bytes": self.bytes_seen,
|
| 170 |
+
"loss": avg_loss,
|
| 171 |
+
}
|
| 172 |
+
torch.save(ckpt, f"/workspace/dibit_ckpt_{kb}kb.pt")
|
| 173 |
+
print(f"[SAVED] dibit_ckpt_{kb}kb.pt", flush=True)
|
| 174 |
+
|
| 175 |
+
def main():
|
| 176 |
+
print(f"DIBIT TRANSFORMER - Vocab = 4 (00, 01, 10, 11)", flush=True)
|
| 177 |
+
print(f"Config: {CONFIG}", flush=True)
|
| 178 |
+
print(f"Device: {DEVICE}", flush=True)
|
| 179 |
+
|
| 180 |
+
model = DibitTransformer(CONFIG)
|
| 181 |
+
params = model.count_params()
|
| 182 |
+
print(f"Parameters: {params:,} ({params/1e6:.2f}M)", flush=True)
|
| 183 |
+
print(f"Vocab: 4 (2-bit tokens: 00, 01, 10, 11)", flush=True)
|
| 184 |
+
print(f"Each byte = 4 dibit tokens", flush=True)
|
| 185 |
+
print(f"Context: {CONFIG['ctx']} dibits = {CONFIG['ctx']//4} bytes", flush=True)
|
| 186 |
+
|
| 187 |
+
trainer = DibitTrainer(model)
|
| 188 |
+
|
| 189 |
+
print(f"Listening for bytes (converting to dibits)...", flush=True)
|
| 190 |
+
|
| 191 |
+
while True:
|
| 192 |
+
byte = sys.stdin.buffer.read(1)
|
| 193 |
+
if not byte:
|
| 194 |
+
break
|
| 195 |
+
trainer.ingest_byte(byte[0])
|
| 196 |
+
|
| 197 |
+
print(f"Stream ended. Total: {trainer.bytes_seen:,} bytes = {trainer.dibits_seen:,} dibits", flush=True)
|
| 198 |
+
|
| 199 |
+
if __name__ == "__main__":
|
| 200 |
+
main()
|
purebit_trainer.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
PURE BINARY TRANSFORMER - BITS ALL THE WAY DOWN
|
| 4 |
+
- Vocab = 2 (0 and 1)
|
| 5 |
+
- Weights = binary (-1 or +1, stored as bits)
|
| 6 |
+
- Activations = binary where possible
|
| 7 |
+
|
| 8 |
+
Uses Straight-Through Estimator (STE) for gradients.
|
| 9 |
+
XNOR + popcount for matmul = insanely fast on hardware.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import sys
|
| 13 |
+
import math
|
| 14 |
+
import time
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
from collections import deque
|
| 19 |
+
|
| 20 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 21 |
+
|
| 22 |
+
# Config for pure binary transformer
|
| 23 |
+
CONFIG = {
|
| 24 |
+
"d": 256, # must be divisible by heads
|
| 25 |
+
"layers": 6,
|
| 26 |
+
"heads": 8,
|
| 27 |
+
"vocab": 2, # 0 and 1
|
| 28 |
+
"ctx": 2048,
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
LR = 1e-3
|
| 32 |
+
UPDATE_EVERY = 256
|
| 33 |
+
PRINT_EVERY = 50000
|
| 34 |
+
|
| 35 |
+
# ============== BINARY LAYERS ==============
|
| 36 |
+
|
| 37 |
+
class BinarySign(torch.autograd.Function):
|
| 38 |
+
"""Binarize to -1/+1 with straight-through estimator"""
|
| 39 |
+
@staticmethod
|
| 40 |
+
def forward(ctx, x):
|
| 41 |
+
ctx.save_for_backward(x)
|
| 42 |
+
return x.sign()
|
| 43 |
+
|
| 44 |
+
@staticmethod
|
| 45 |
+
def backward(ctx, grad_output):
|
| 46 |
+
x, = ctx.saved_tensors
|
| 47 |
+
# STE: pass gradient through if |x| <= 1
|
| 48 |
+
grad_input = grad_output.clone()
|
| 49 |
+
grad_input[x.abs() > 1] = 0
|
| 50 |
+
return grad_input
|
| 51 |
+
|
| 52 |
+
def binarize(x):
|
| 53 |
+
return BinarySign.apply(x)
|
| 54 |
+
|
| 55 |
+
class BinaryLinear(nn.Module):
|
| 56 |
+
"""Linear layer with binary weights (-1/+1)"""
|
| 57 |
+
def __init__(self, in_features, out_features, bias=False):
|
| 58 |
+
super().__init__()
|
| 59 |
+
self.in_features = in_features
|
| 60 |
+
self.out_features = out_features
|
| 61 |
+
|
| 62 |
+
# Real-valued weights for training, binarized during forward
|
| 63 |
+
self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.1)
|
| 64 |
+
if bias:
|
| 65 |
+
self.bias = nn.Parameter(torch.zeros(out_features))
|
| 66 |
+
else:
|
| 67 |
+
self.bias = None
|
| 68 |
+
|
| 69 |
+
def forward(self, x):
|
| 70 |
+
# Binarize weights to -1/+1
|
| 71 |
+
binary_weight = binarize(self.weight)
|
| 72 |
+
|
| 73 |
+
# Scale factor for better gradients (from XNOR-Net paper)
|
| 74 |
+
# alpha = mean(|W|)
|
| 75 |
+
alpha = self.weight.abs().mean()
|
| 76 |
+
|
| 77 |
+
out = F.linear(x, binary_weight * alpha, self.bias)
|
| 78 |
+
return out
|
| 79 |
+
|
| 80 |
+
class BinaryAttention(nn.Module):
|
| 81 |
+
"""Attention with binary QKV projections"""
|
| 82 |
+
def __init__(self, d, h):
|
| 83 |
+
super().__init__()
|
| 84 |
+
self.h, self.dk = h, d // h
|
| 85 |
+
self.q_proj = BinaryLinear(d, d)
|
| 86 |
+
self.k_proj = BinaryLinear(d, d)
|
| 87 |
+
self.v_proj = BinaryLinear(d, d)
|
| 88 |
+
self.out_proj = BinaryLinear(d, d)
|
| 89 |
+
|
| 90 |
+
def forward(self, x, mask=None):
|
| 91 |
+
B, N, D = x.shape
|
| 92 |
+
|
| 93 |
+
q = self.q_proj(x).view(B, N, self.h, self.dk).transpose(1, 2)
|
| 94 |
+
k = self.k_proj(x).view(B, N, self.h, self.dk).transpose(1, 2)
|
| 95 |
+
v = self.v_proj(x).view(B, N, self.h, self.dk).transpose(1, 2)
|
| 96 |
+
|
| 97 |
+
# Standard attention (values stay real for now)
|
| 98 |
+
att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
|
| 99 |
+
if mask is not None:
|
| 100 |
+
att = att + mask
|
| 101 |
+
att = F.softmax(att, dim=-1)
|
| 102 |
+
|
| 103 |
+
out = (att @ v).transpose(1, 2).reshape(B, N, D)
|
| 104 |
+
return self.out_proj(out)
|
| 105 |
+
|
| 106 |
+
class BinaryMLP(nn.Module):
|
| 107 |
+
"""MLP with binary weights"""
|
| 108 |
+
def __init__(self, d):
|
| 109 |
+
super().__init__()
|
| 110 |
+
self.fc1 = BinaryLinear(d, d * 4)
|
| 111 |
+
self.fc2 = BinaryLinear(d * 4, d)
|
| 112 |
+
|
| 113 |
+
def forward(self, x):
|
| 114 |
+
# Binary weights, but ReLU activation (could binarize this too)
|
| 115 |
+
x = F.gelu(self.fc1(x))
|
| 116 |
+
return self.fc2(x)
|
| 117 |
+
|
| 118 |
+
class BinaryBlock(nn.Module):
|
| 119 |
+
def __init__(self, d, h):
|
| 120 |
+
super().__init__()
|
| 121 |
+
self.ln1 = nn.LayerNorm(d)
|
| 122 |
+
self.attn = BinaryAttention(d, h)
|
| 123 |
+
self.ln2 = nn.LayerNorm(d)
|
| 124 |
+
self.mlp = BinaryMLP(d)
|
| 125 |
+
|
| 126 |
+
def forward(self, x, mask):
|
| 127 |
+
x = x + self.attn(self.ln1(x), mask)
|
| 128 |
+
return x + self.mlp(self.ln2(x))
|
| 129 |
+
|
| 130 |
+
class PureBinaryTransformer(nn.Module):
|
| 131 |
+
"""
|
| 132 |
+
Transformer where:
|
| 133 |
+
- Input vocab = 2 (bits)
|
| 134 |
+
- All linear weights are binary (-1/+1)
|
| 135 |
+
"""
|
| 136 |
+
def __init__(self, cfg):
|
| 137 |
+
super().__init__()
|
| 138 |
+
d, L, h = cfg["d"], cfg["layers"], cfg["heads"]
|
| 139 |
+
|
| 140 |
+
# Embeddings stay real (only 2 of them anyway)
|
| 141 |
+
self.emb = nn.Embedding(2, d)
|
| 142 |
+
|
| 143 |
+
# Binary blocks
|
| 144 |
+
self.blocks = nn.ModuleList([BinaryBlock(d, h) for _ in range(L)])
|
| 145 |
+
|
| 146 |
+
self.ln = nn.LayerNorm(d)
|
| 147 |
+
self.head = BinaryLinear(d, 2) # Binary output projection too!
|
| 148 |
+
|
| 149 |
+
def forward(self, x):
|
| 150 |
+
B, N = x.shape
|
| 151 |
+
mask = torch.triu(torch.ones(N, N, device=x.device), 1) * -1e9
|
| 152 |
+
|
| 153 |
+
h = self.emb(x)
|
| 154 |
+
for block in self.blocks:
|
| 155 |
+
h = block(h, mask)
|
| 156 |
+
|
| 157 |
+
return self.head(self.ln(h))
|
| 158 |
+
|
| 159 |
+
def count_params(self):
|
| 160 |
+
return sum(p.numel() for p in self.parameters())
|
| 161 |
+
|
| 162 |
+
def count_binary_params(self):
|
| 163 |
+
"""Count params that are binarized"""
|
| 164 |
+
count = 0
|
| 165 |
+
for name, module in self.named_modules():
|
| 166 |
+
if isinstance(module, BinaryLinear):
|
| 167 |
+
count += module.weight.numel()
|
| 168 |
+
return count
|
| 169 |
+
|
| 170 |
+
def byte_to_bits(byte_val):
|
| 171 |
+
return [(byte_val >> (7 - i)) & 1 for i in range(8)]
|
| 172 |
+
|
| 173 |
+
class BinaryTrainer:
|
| 174 |
+
def __init__(self, model, lr=LR):
|
| 175 |
+
self.model = model.to(DEVICE)
|
| 176 |
+
self.opt = torch.optim.AdamW(model.parameters(), lr=lr)
|
| 177 |
+
self.ctx_size = CONFIG["ctx"]
|
| 178 |
+
self.buffer = deque(maxlen=self.ctx_size + 1)
|
| 179 |
+
|
| 180 |
+
self.bits_seen = 0
|
| 181 |
+
self.bytes_seen = 0
|
| 182 |
+
self.total_loss = 0.0
|
| 183 |
+
self.updates = 0
|
| 184 |
+
self.start_time = time.time()
|
| 185 |
+
|
| 186 |
+
def ingest_byte(self, byte_val):
|
| 187 |
+
bits = byte_to_bits(byte_val)
|
| 188 |
+
for bit in bits:
|
| 189 |
+
self.buffer.append(bit)
|
| 190 |
+
self.bits_seen += 1
|
| 191 |
+
|
| 192 |
+
if len(self.buffer) >= UPDATE_EVERY + 1 and self.bits_seen % UPDATE_EVERY == 0:
|
| 193 |
+
self._update()
|
| 194 |
+
|
| 195 |
+
self.bytes_seen += 1
|
| 196 |
+
|
| 197 |
+
if self.bits_seen % PRINT_EVERY == 0:
|
| 198 |
+
self._print_stats()
|
| 199 |
+
|
| 200 |
+
if self.bytes_seen % 500000 == 0 and self.bytes_seen > 0:
|
| 201 |
+
self._save()
|
| 202 |
+
|
| 203 |
+
def _update(self):
|
| 204 |
+
tokens = list(self.buffer)
|
| 205 |
+
x = torch.tensor(tokens[:-1], device=DEVICE, dtype=torch.long).unsqueeze(0)
|
| 206 |
+
y = torch.tensor(tokens[1:], device=DEVICE, dtype=torch.long).unsqueeze(0)
|
| 207 |
+
|
| 208 |
+
self.model.train()
|
| 209 |
+
logits = self.model(x)
|
| 210 |
+
loss = F.cross_entropy(
|
| 211 |
+
logits[:, -UPDATE_EVERY:].reshape(-1, 2),
|
| 212 |
+
y[:, -UPDATE_EVERY:].reshape(-1)
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
self.opt.zero_grad()
|
| 216 |
+
loss.backward()
|
| 217 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
|
| 218 |
+
self.opt.step()
|
| 219 |
+
|
| 220 |
+
self.total_loss += loss.item()
|
| 221 |
+
self.updates += 1
|
| 222 |
+
|
| 223 |
+
def _print_stats(self):
|
| 224 |
+
elapsed = time.time() - self.start_time
|
| 225 |
+
bytes_per_sec = self.bytes_seen / elapsed if elapsed > 0 else 0
|
| 226 |
+
avg_loss = self.total_loss / max(1, self.updates)
|
| 227 |
+
|
| 228 |
+
entropy = avg_loss / math.log(2)
|
| 229 |
+
compression = (1.0 - entropy) * 100
|
| 230 |
+
|
| 231 |
+
print(f"[{elapsed:.0f}s] {self.bytes_seen/1000:.1f}KB | {bytes_per_sec/1000:.2f} KB/s | "
|
| 232 |
+
f"loss={avg_loss:.4f} | entropy={entropy:.3f} | compression={compression:.1f}%", flush=True)
|
| 233 |
+
|
| 234 |
+
def _save(self):
|
| 235 |
+
avg_loss = self.total_loss / max(1, self.updates)
|
| 236 |
+
kb = self.bytes_seen // 1000
|
| 237 |
+
ckpt = {
|
| 238 |
+
"model": self.model.state_dict(),
|
| 239 |
+
"bits": self.bits_seen,
|
| 240 |
+
"bytes": self.bytes_seen,
|
| 241 |
+
"loss": avg_loss,
|
| 242 |
+
}
|
| 243 |
+
torch.save(ckpt, f"/workspace/purebit_ckpt_{kb}kb.pt")
|
| 244 |
+
print(f"[SAVED] purebit_ckpt_{kb}kb.pt", flush=True)
|
| 245 |
+
|
| 246 |
+
def main():
|
| 247 |
+
print(f"PURE BINARY TRANSFORMER - BITS ALL THE WAY DOWN", flush=True)
|
| 248 |
+
print(f"Config: {CONFIG}", flush=True)
|
| 249 |
+
print(f"Device: {DEVICE}", flush=True)
|
| 250 |
+
|
| 251 |
+
model = PureBinaryTransformer(CONFIG)
|
| 252 |
+
total_params = model.count_params()
|
| 253 |
+
binary_params = model.count_binary_params()
|
| 254 |
+
|
| 255 |
+
print(f"Total Parameters: {total_params:,} ({total_params/1e6:.2f}M)", flush=True)
|
| 256 |
+
print(f"Binary Parameters: {binary_params:,} ({binary_params/total_params*100:.1f}%)", flush=True)
|
| 257 |
+
print(f"Vocab: 2 (input bits)", flush=True)
|
| 258 |
+
print(f"Weights: BINARY (-1/+1)", flush=True)
|
| 259 |
+
print(f"", flush=True)
|
| 260 |
+
print(f"🔥 BITS IN, BITS WEIGHTS, BITS OUT 🔥", flush=True)
|
| 261 |
+
|
| 262 |
+
trainer = BinaryTrainer(model)
|
| 263 |
+
|
| 264 |
+
print(f"Listening for bytes...", flush=True)
|
| 265 |
+
|
| 266 |
+
while True:
|
| 267 |
+
byte = sys.stdin.buffer.read(1)
|
| 268 |
+
if not byte:
|
| 269 |
+
break
|
| 270 |
+
trainer.ingest_byte(byte[0])
|
| 271 |
+
|
| 272 |
+
print(f"Done. {trainer.bytes_seen:,} bytes = {trainer.bits_seen:,} bits", flush=True)
|
| 273 |
+
|
| 274 |
+
if __name__ == "__main__":
|
| 275 |
+
main()
|