Upload BitPixelLM model artifacts
- README.md +76 -0
- app.py +310 -0
- best.pt +3 -0
- config.json +7 -0
- generate.py +196 -0
- model/__init__.py +17 -0
- model/bit_pixel_decoder.py +577 -0
- model/bitlinear.py +239 -0
- model/text_encoder.py +122 -0
- model/tokenizer.py +106 -0
README.md
ADDED
@@ -0,0 +1,76 @@

# BitPixelLM

BitPixelLM is a text-to-pixel-art language model that generates 32x32 images from prompts like `a red pixel art sword`.

It uses a BitNet b1.58-style ternary decoder (weights in `{-1, 0, +1}`) with a lightweight text encoder.

## Current Model Snapshot

- Model name: **BitPixelLM**
- Architecture: 3-layer text encoder + 6-layer BitPixelLM decoder
- Parameters: ~7.3M
- Dataset (v3): 23,648 synthetic pixel-art samples
- Vocab: 222 words
- Best validation loss (v3): ~0.4015

## Project Layout

- `model/bit_pixel_decoder.py` — BitPixelLM model
- `train_bitnet.py` — training pipeline
- `generate.py` — CLI generation
- `app.py` — Gradio app
- `data/generate_v3.py` — v3 dataset generator
- `PixelArtGen_Colab.ipynb` — Colab training notebook

## Run Locally

1. Ensure Python 3.9 and a CUDA-enabled PyTorch install.
2. Place data in `D:\PixelArtGen_Data\processed`:
   - `tokens.npy`, `labels.json`, `vocab.json`, `palette_256.npy`
3. Train:

```bash
python train_bitnet.py --epochs 60 --batch-size 32 --lr 5e-4
```

4. Launch the app:

```bash
python app.py
```

## Publish to Hugging Face

This repo includes `publish_hf.py` for one-step upload.
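For orientation, a minimal sketch of what such an upload script can look like with `huggingface_hub`. This is not the actual `publish_hf.py` (its source is not part of this commit); the file list mirrors the upload summary below, and the repo id is a placeholder:

```python
# Hypothetical upload sketch — not the actual publish_hf.py from this repo.
import os
from huggingface_hub import HfApi, create_repo

def publish(repo_id: str, token: str):
    create_repo(repo_id, token=token, exist_ok=True)  # no-op if the repo already exists
    api = HfApi(token=token)
    # Checkpoint (large file, tracked via LFS automatically)
    api.upload_file(
        path_or_fileobj="checkpoints_bit/best.pt",
        path_in_repo="best.pt",
        repo_id=repo_id,
    )
    # Model code and entry points
    api.upload_folder(folder_path="model", path_in_repo="model", repo_id=repo_id)
    for f in ("generate.py", "app.py", "README.md", "config.json"):
        api.upload_file(path_or_fileobj=f, path_in_repo=f, repo_id=repo_id)

if __name__ == "__main__":
    publish("YOUR_USERNAME/BitPixelLM", os.environ["HF_TOKEN"])
```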
### Required

- Hugging Face token with write access (`HF_TOKEN`)
- `huggingface_hub` installed

### Command

```bash
pip install huggingface_hub
python publish_hf.py --repo-id YOUR_USERNAME/BitPixelLM --token $HF_TOKEN
```

On Windows PowerShell:

```powershell
$env:HF_TOKEN = "hf_xxx"
python publish_hf.py --repo-id YOUR_USERNAME/BitPixelLM --token $env:HF_TOKEN
```

This uploads:

- `checkpoints_bit/best.pt`
- `model/` Python files
- `generate.py`
- `app.py`
- `README.md` (model card / usage overview)

## Notes

- The active production model is **BitPixelLM**.
- Legacy FP32 `PixelLM` artifacts remain in the repo only for historical reference.
app.py
ADDED
@@ -0,0 +1,310 @@

"""
PixelArtGen — Gradio Web UI

Interactive UI to generate pixel art from text prompts using
BitPixelLM — a 1.58-bit ternary transformer (BitNet b1.58).

Launch:
    python app.py
Then open http://localhost:7860 in your browser.
"""

import sys
import json
import torch
import gradio as gr
from pathlib import Path
from PIL import Image

sys.path.insert(0, str(Path(__file__).parent))

from model.tokenizer import PaletteTokenizer
from model.text_encoder import TextTokenizer, TextEncoder
from model.bit_pixel_decoder import BitPixelLMDecoder, BitPixelLM

# ─── Config ──────────────────────────────────────────────────────
DATA_DIR = Path(r"D:\PixelArtGen_Data\processed")
CHECKPOINT_PATH = Path("checkpoints_bit/best.pt")

# ─── Global state (loaded once) ─────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = None
palette_tok = None
text_tok = None


def load_tokenizers():
    """Load shared tokenizers."""
    global palette_tok, text_tok
    palette_tok = PaletteTokenizer(palette_path=str(DATA_DIR / "palette_256.npy"))
    with open(DATA_DIR / "vocab.json") as f:
        vocab = json.load(f)
    text_tok = TextTokenizer(vocab)


def load_model():
    """Load the BitPixelLM model from checkpoint."""
    global model
    if model is not None:
        return model

    if not CHECKPOINT_PATH.exists():
        raise FileNotFoundError(
            f"Checkpoint not found: {CHECKPOINT_PATH}\n"
            "BitPixelLM is still training — check back once training completes."
        )

    checkpoint = torch.load(str(CHECKPOINT_PATH), map_location=device, weights_only=False)
    model_args = checkpoint.get("args", {})

    d_model = model_args.get("d_model", 256)
    nhead = model_args.get("nhead", 8)
    text_layers = model_args.get("text_layers", 3)
    pixel_layers = model_args.get("pixel_layers", 6)
    dim_ff = model_args.get("dim_ff", 512)
    dropout = model_args.get("dropout", 0.1)
    max_text_len = model_args.get("max_text_len", 32)

    text_encoder = TextEncoder(
        vocab_size=text_tok.vocab_size,
        d_model=d_model,
        nhead=nhead,
        num_layers=text_layers,
        dim_feedforward=dim_ff,
        max_seq_len=max_text_len,
        dropout=dropout,
    )

    pixel_decoder = BitPixelLMDecoder(
        vocab_size=palette_tok.vocab_size,
        d_model=d_model,
        nhead=nhead,
        num_layers=pixel_layers,
        dim_feedforward=dim_ff,
        img_size=32,
        dropout=dropout,
    )
    m = BitPixelLM(text_encoder, pixel_decoder).to(device)

    m.load_state_dict(checkpoint["model_state_dict"])
    m.eval()
    model = m
    return model


def generate(
    prompt: str,
    temperature: float,
    top_k: int,
    top_p: float,
    num_samples: int,
    scale: int,
):
    """Generate pixel art from a text prompt."""
    if not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    if model is None:
        raise gr.Error(
            "BitPixelLM is not loaded yet. "
            "It may still be training — check back once training completes."
        )

    text_tokens = text_tok.encode(prompt).unsqueeze(0).to(device)

    # Warn about unknown words (generation still runs, but quality may suffer)
    words = prompt.lower().strip().split()
    unknown = [w for w in words if w not in text_tok.word2idx and w not in ("<pad>", "<sos>", "<eos>", "<unk>")]

    images = []
    try:
        for _ in range(int(num_samples)):
            with torch.no_grad():
                generated_tokens = model.generate(
                    text_tokens,
                    sos_token=palette_tok.sos_token,
                    eos_token=palette_tok.eos_token,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                )

            token_list = generated_tokens[0].cpu().tolist()
            img_array = palette_tok.decode_tokens(token_list)
            img = Image.fromarray(img_array, "RGB")

            # Upscale with nearest-neighbor for crisp pixels
            s = int(scale)
            if s > 1:
                img = img.resize((32 * s, 32 * s), Image.NEAREST)

            images.append(img)
    except Exception as e:
        raise gr.Error(f"Generation failed: {e}")

    if unknown:
        gr.Warning(
            f"Unknown words treated as <unk>: {', '.join(unknown)}. "
            "Try using words from the vocabulary list below."
        )

    return images


# ─── Build UI ────────────────────────────────────────────────────

# Load vocabulary dynamically from processed data
def _load_vocab_words():
    try:
        with open(DATA_DIR / "vocab.json") as f:
            vocab = json.load(f)
        return sorted([w for w in vocab if not w.startswith("<")])
    except Exception:
        return ["pixel", "art", "sword", "red", "blue", "green"]

VOCAB_WORDS = _load_vocab_words()

EXAMPLE_PROMPTS = [
    "a red pixel art sword",
    "a green pixel art dragon",
    "a purple pixel art crystal",
    "a blue pixel art knight",
    "a gold pixel art castle",
    "a red pixel art phoenix",
    "a dark pixel art skeleton",
    "a teal pixel art wizard",
    "a silver pixel art robot",
    "a orange pixel art fox",
]


def build_ui():
    with gr.Blocks(
        title="PixelArtGen",
        theme=gr.themes.Soft(primary_hue="purple"),
        css="""
        .gallery-item img { image-rendering: pixelated !important; }
        .output-gallery img { image-rendering: pixelated !important; }
        #gallery img { image-rendering: pixelated !important; }
        """,
    ) as app:
        gr.Markdown(
            """
            # PixelArtGen
            ### Generate 32x32 pixel art from text prompts

            Powered by **BitPixelLM** — a custom 1.58-bit ternary transformer built from scratch
            using BitNet b1.58 with RMSNorm, SwiGLU, and 2D positional encoding.
            7.3M parameters (75% ternary weights at 1.58 bits per weight).
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="a red pixel art sword",
                    lines=2,
                )
                with gr.Row():
                    generate_btn = gr.Button("Generate", variant="primary", scale=2)
                    num_samples = gr.Slider(1, 8, value=4, step=1, label="Samples")

                with gr.Accordion("Advanced Settings", open=False):
                    temperature = gr.Slider(
                        0.1, 2.0, value=0.8, step=0.05,
                        label="Temperature",
                        info="Lower = more deterministic, higher = more creative"
                    )
                    top_k = gr.Slider(
                        0, 256, value=40, step=1,
                        label="Top-K",
                        info="0 = disabled. Limits sampling to the top K tokens."
                    )
                    top_p = gr.Slider(
                        0.1, 1.0, value=0.9, step=0.05,
                        label="Top-P (Nucleus)",
                        info="Cumulative probability threshold for sampling."
                    )
                    scale = gr.Slider(
                        1, 16, value=8, step=1,
                        label="Upscale Factor",
                        info="8x = 256x256, 16x = 512x512"
                    )

                gr.Markdown(
                    f"**Known vocabulary:** {', '.join(VOCAB_WORDS)}"
                )

            with gr.Column(scale=2):
                gallery = gr.Gallery(
                    label="Generated Pixel Art",
                    columns=4,
                    rows=2,
                    height=520,
                    object_fit="contain",
                    elem_id="gallery",
                )

                gr.Markdown("### Examples")
                gr.Examples(
                    examples=EXAMPLE_PROMPTS,
                    inputs=[prompt],
                    label="Click to try",
                )

        gr.Markdown(
            """
            ---
            **Architecture:**
            BitPixelLM treats pixel art generation as language modeling — each pixel is a token from a 256-color palette,
            generated left-to-right, top-to-bottom via a causal transformer with 2D positional encoding and cross-attention to text.
            Uses 1.58-bit ternary weights (BitNet b1.58) with RMSNorm and SwiGLU for extreme parameter efficiency.
            """
        )

        # Wire up the generate button
        generate_btn.click(
            fn=generate,
            inputs=[prompt, temperature, top_k, top_p, num_samples, scale],
            outputs=gallery,
        )

        # Also generate on Enter
        prompt.submit(
            fn=generate,
            inputs=[prompt, temperature, top_k, top_p, num_samples, scale],
            outputs=gallery,
        )

    return app


# ─── Main ────────────────────────────────────────────────────────
if __name__ == "__main__":
    print("Loading tokenizers...")
    load_tokenizers()
    print(f"  Palette: {palette_tok.vocab_size} tokens")
    print(f"  Text: {text_tok.vocab_size} words")
    print(f"  Device: {device}")

    # Load BitPixelLM
    print(f"Loading BitPixelLM from {CHECKPOINT_PATH}...")
    try:
        load_model()
        print("  BitPixelLM loaded successfully.")
    except FileNotFoundError as e:
        print(f"  {e}")
        print("  UI will launch but generation will be unavailable until training completes.")
    except Exception as e:
        print(f"  Failed to load BitPixelLM: {e}")

    print("\nLaunching UI...")
    app = build_ui()
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        inbrowser=True,
    )
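Once the app is running, it can also be driven programmatically. A hedged sketch with `gradio_client` follows; the endpoint name `/generate` (Gradio typically names endpoints after the wired function) and the positional argument order mirroring `generate` above are both assumptions, not confirmed by this commit:

```python
# Hypothetical client-side call — endpoint name and argument order are assumptions.
from gradio_client import Client

client = Client("http://localhost:7860")
result = client.predict(
    "a red pixel art sword",  # prompt
    0.8,                      # temperature
    40,                       # top_k
    0.9,                      # top_p
    4,                        # num_samples
    8,                        # scale
    api_name="/generate",
)
print(result)  # local paths to the generated gallery images
```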
best.pt
ADDED
@@ -0,0 +1,3 @@

version https://git-lfs.github.com/spec/v1
oid sha256:37ceef8a7d844445be4bc5730bcd683d1512aff084a6e872634b3184c58f2464
size 88732053
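`best.pt` is stored via Git LFS (88,732,053 bytes, about 85 MB); only this pointer file lives in Git. A minimal sketch for fetching the real checkpoint from the Hub, assuming the placeholder repo id from the README:

```python
# Sketch: download the LFS-backed checkpoint from the Hub (repo id is a placeholder).
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(repo_id="YOUR_USERNAME/BitPixelLM", filename="best.pt")
print(ckpt_path)  # local cache path to the checkpoint file
```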
config.json
ADDED
@@ -0,0 +1,7 @@

{
  "model_type": "BitPixelLM",
  "architecture": "BitNet-b1.58-style autoregressive decoder",
  "image_size": 32,
  "task": "text-to-image (pixel art)",
  "checkpoint_file": "best.pt"
}
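This is lightweight repo metadata rather than a `transformers`-style config; nothing in the code above reads it. A small sketch of how a consumer might use it to locate the checkpoint (the field names match the JSON above; the usage itself is an assumption):

```python
# Sketch: read the repo metadata to find the checkpoint file.
import json

with open("config.json") as f:
    cfg = json.load(f)

assert cfg["model_type"] == "BitPixelLM"
print(f"Checkpoint: {cfg['checkpoint_file']} "
      f"({cfg['image_size']}x{cfg['image_size']} {cfg['task']})")
```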
generate.py
ADDED
@@ -0,0 +1,196 @@

"""
PixelArtGen — Generate pixel art from text prompts.

Usage:
    python generate.py --prompt "a red pixel art sword" --output output.png
    python generate.py --prompt "a blue pixel art heart" --output heart.png --temperature 0.7
    python generate.py --batch-prompts prompts.txt --output-dir outputs/
"""

import sys
import json
import argparse
import torch
from pathlib import Path
from PIL import Image

sys.path.insert(0, str(Path(__file__).parent))

from model.tokenizer import PaletteTokenizer
from model.text_encoder import TextTokenizer, TextEncoder
# NOTE: this CLI targets the legacy FP32 PixelLM (see README "Notes");
# model/pixel_decoder.py is not part of this upload.
from model.pixel_decoder import PixelLMDecoder, PixelLM


def load_model(checkpoint_path: str, data_dir: str, device: torch.device):
    """Load a trained PixelLM model from checkpoint."""
    data_dir = Path(data_dir)

    # Load checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model_args = checkpoint.get("args", {})

    # Load tokenizers
    palette_tok = PaletteTokenizer(palette_path=str(data_dir / "palette_256.npy"))

    with open(data_dir / "vocab.json") as f:
        vocab = json.load(f)
    text_tok = TextTokenizer(vocab)

    # Rebuild model
    d_model = model_args.get("d_model", 256)
    nhead = model_args.get("nhead", 8)
    text_layers = model_args.get("text_layers", 3)
    pixel_layers = model_args.get("pixel_layers", 6)
    dim_ff = model_args.get("dim_ff", 512)
    dropout = model_args.get("dropout", 0.1)
    max_text_len = model_args.get("max_text_len", 32)

    text_encoder = TextEncoder(
        vocab_size=text_tok.vocab_size,
        d_model=d_model,
        nhead=nhead,
        num_layers=text_layers,
        dim_feedforward=dim_ff,
        max_seq_len=max_text_len,
        dropout=dropout,
    )

    pixel_decoder = PixelLMDecoder(
        vocab_size=palette_tok.vocab_size,
        d_model=d_model,
        nhead=nhead,
        num_layers=pixel_layers,
        dim_feedforward=dim_ff,
        img_size=32,
        dropout=dropout,
    )

    model = PixelLM(text_encoder, pixel_decoder).to(device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    return model, palette_tok, text_tok


def generate_pixel_art(
    model: PixelLM,
    palette_tok: PaletteTokenizer,
    text_tok: TextTokenizer,
    prompt: str,
    device: torch.device,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 0.9,
    scale: int = 8,
) -> Image.Image:
    """
    Generate a 32×32 pixel art image from a text prompt.

    Args:
        model: Trained PixelLM model
        palette_tok: Color palette tokenizer
        text_tok: Text tokenizer
        prompt: Text description
        device: torch device
        temperature: Sampling temperature (lower = more deterministic)
        top_k: Top-k filtering
        top_p: Nucleus sampling threshold
        scale: Upscale factor for display (8 = 256×256 output)
    Returns:
        PIL Image (32*scale × 32*scale)
    """
    # Tokenize prompt
    text_tokens = text_tok.encode(prompt).unsqueeze(0).to(device)

    # Generate
    with torch.no_grad():
        generated_tokens = model.generate(
            text_tokens,
            sos_token=palette_tok.sos_token,
            eos_token=palette_tok.eos_token,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

    # Decode to image
    token_list = generated_tokens[0].cpu().tolist()
    img_array = palette_tok.decode_tokens(token_list)
    img = Image.fromarray(img_array, "RGB")

    # Upscale with nearest-neighbor (pixel art style)
    if scale > 1:
        img = img.resize((32 * scale, 32 * scale), Image.NEAREST)

    return img


def main():
    parser = argparse.ArgumentParser(description="Generate pixel art from text")
    parser.add_argument("--prompt", type=str, help="Text prompt")
    parser.add_argument("--output", type=str, default="output.png", help="Output file")
    parser.add_argument("--checkpoint", type=str, default="checkpoints/best.pt")
    parser.add_argument("--data-dir", type=str, default=r"D:\PixelArtGen_Data\processed")
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=40)
    parser.add_argument("--top-p", type=float, default=0.9)
    parser.add_argument("--scale", type=int, default=8, help="Upscale factor")
    parser.add_argument("--num-samples", type=int, default=1, help="Number of images to generate")
    parser.add_argument("--batch-prompts", type=str, help="File with prompts (one per line)")
    parser.add_argument("--output-dir", type=str, default="outputs")

    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Load model
    print(f"Loading model from {args.checkpoint}...")
    model, palette_tok, text_tok = load_model(args.checkpoint, args.data_dir, device)
    print(f"  Model: {model.count_parameters():,} parameters")

    # Collect prompts
    if args.batch_prompts:
        with open(args.batch_prompts) as f:
            prompts = [line.strip() for line in f if line.strip()]
    elif args.prompt:
        prompts = [args.prompt]
    else:
        prompts = [
            "a red pixel art sword",
            "a blue pixel art heart",
            "a green pixel art tree",
            "a purple pixel art gem",
        ]

    # Generate
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for prompt in prompts:
        print(f"\nGenerating: \"{prompt}\"")
        for j in range(args.num_samples):
            img = generate_pixel_art(
                model, palette_tok, text_tok, prompt, device,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
                scale=args.scale,
            )

            if len(prompts) == 1 and args.num_samples == 1:
                out_path = args.output
            else:
                safe_name = prompt.replace(" ", "_")[:30]
                out_path = output_dir / f"{safe_name}_{j}.png"

            img.save(str(out_path))
            print(f"  Saved: {out_path}")

    print("\nDone!")


if __name__ == "__main__":
    main()
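Since `load_model` and `generate_pixel_art` are plain functions, the module can also be used as a library rather than a CLI. A short sketch, assuming the default checkpoint and data paths above exist on your machine:

```python
# Sketch: use generate.py as a library instead of a CLI (paths are assumptions).
import torch
from generate import load_model, generate_pixel_art

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, palette_tok, text_tok = load_model(
    "checkpoints/best.pt", r"D:\PixelArtGen_Data\processed", device
)

img = generate_pixel_art(model, palette_tok, text_tok,
                         "a red pixel art sword", device, temperature=0.7)
img.save("sword.png")  # 256x256 PNG (32x32 upscaled 8x by default)
```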
model/__init__.py
ADDED
@@ -0,0 +1,17 @@

"""PixelArtGen model package."""

from .tokenizer import PaletteTokenizer
from .text_encoder import TextTokenizer, TextEncoder
from .bitlinear import BitLinear158, RMSNorm, SwiGLU
from .bit_pixel_decoder import BitPixelLMDecoder, BitPixelLM

__all__ = [
    "PaletteTokenizer",
    "TextTokenizer",
    "TextEncoder",
    "BitLinear158",
    "RMSNorm",
    "SwiGLU",
    "BitPixelLMDecoder",
    "BitPixelLM",
]
model/bit_pixel_decoder.py
ADDED
@@ -0,0 +1,577 @@

"""
PixelArtGen — BitPixelLM Decoder (1.58-bit)

A ternary-weight variant of our PixelLM decoder, implementing BitNet b1.58.
Replaces nn.Linear layers with BitLinear158 (ternary weights {-1, 0, +1})
and uses modern LLaMA-style components (RMSNorm, SwiGLU, no biases).

Key differences from the standard PixelLM decoder:
- BitLinear158 layers with built-in RMSNorm (replaces nn.Linear + LayerNorm)
- SwiGLU FFN activation (replaces GELU)
- No biases anywhere
- Token embeddings and output head remain in full precision
- 2D positional encoding preserved (our unique contribution)

References:
- "The Era of 1-bit LLMs" (Ma et al., 2024) — arXiv:2402.17764
- "BitNet" (Wang et al., 2023) — arXiv:2310.11453
- "GLU Variants Improve Transformer" (Shazeer, 2020) — arXiv:2002.05202
- "RMSNorm" (Zhang & Sennrich, 2019) — arXiv:1910.07467
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

try:
    from model.bitlinear import BitLinear158, RMSNorm, SwiGLU
except ImportError:
    # Allow running this file directly (`python model/bit_pixel_decoder.py`):
    # put the repo root on sys.path so the `model` package resolves.
    import sys
    from pathlib import Path
    sys.path.insert(0, str(Path(__file__).parent.parent))
    from model.bitlinear import BitLinear158, RMSNorm, SwiGLU


# ── Shared components (self-contained, no dependency on pixel_decoder.py) ──

class PixelPositionalEncoding2D(nn.Module):
    """
    2D positional encoding for pixel sequences.

    Instead of treating pixel positions as flat indices 0..1023,
    we encode them as (row, col) pairs with separate learned embeddings.
    This gives the model explicit 2D spatial structure.

    Also includes a special position embedding for <sos> and <eos> tokens.
    """

    def __init__(self, d_model: int, img_size: int = 32):
        super().__init__()
        self.img_size = img_size
        self.d_model = d_model

        # Separate row and column embeddings
        self.row_embed = nn.Embedding(img_size, d_model // 2)
        self.col_embed = nn.Embedding(img_size, d_model // 2)

        # Special position for sos/eos tokens
        self.special_pos = nn.Embedding(2, d_model)  # 0=sos, 1=eos

        # Learnable scale
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """
        Generate positional encodings for a sequence of length seq_len.
        Sequence layout: [sos, pixel_0, pixel_1, ..., pixel_1023, eos]

        Returns: (1, seq_len, d_model)
        """
        positions = torch.zeros(1, seq_len, self.d_model, device=device)

        # SOS position
        positions[:, 0, :] = self.special_pos(torch.tensor([0], device=device))

        # Pixel positions (indices 1..1024)
        num_pixels = min(seq_len - 1, self.img_size * self.img_size)
        if num_pixels > 0:
            pixel_indices = torch.arange(num_pixels, device=device)
            rows = pixel_indices // self.img_size
            cols = pixel_indices % self.img_size

            row_emb = self.row_embed(rows)  # (num_pixels, d_model//2)
            col_emb = self.col_embed(cols)  # (num_pixels, d_model//2)
            pixel_pos = torch.cat([row_emb, col_emb], dim=-1)  # (num_pixels, d_model)
            positions[:, 1:1 + num_pixels, :] = pixel_pos.unsqueeze(0)

        # EOS position (if present)
        if seq_len > self.img_size * self.img_size + 1:
            positions[:, -1, :] = self.special_pos(torch.tensor([1], device=device))

        return positions * self.scale


class PaletteOutputHead(nn.Module):
    """
    Palette-aware output prediction.

    Instead of a flat linear(d_model -> vocab_size) layer, we compute
    output logits via scaled dot-product attention between the decoder
    hidden states and a set of learned palette key vectors.

    Each palette color has a key embedding initialized from its RGB values.
    This gives the model an inductive bias toward understanding color relationships.
    """

    def __init__(self, d_model: int, palette_size: int, num_special_tokens: int = 3):
        super().__init__()
        self.total_vocab = palette_size + num_special_tokens
        self.d_model = d_model

        # Learned palette keys (will be initialized from RGB values)
        self.palette_keys = nn.Parameter(torch.randn(self.total_vocab, d_model))

        # Query projection for hidden states
        self.query_proj = nn.Linear(d_model, d_model)

        # Temperature parameter for controlling sharpness
        self.temperature = nn.Parameter(torch.tensor(math.sqrt(d_model), dtype=torch.float32))

    def init_from_palette(self, palette_rgb: torch.Tensor):
        """
        Initialize palette key embeddings from RGB values.
        palette_rgb: (palette_size, 3) tensor of RGB values [0, 255]
        """
        with torch.no_grad():
            palette_size = palette_rgb.shape[0]
            # Normalize RGB to [-1, 1] and project to d_model
            rgb_norm = palette_rgb.float() / 127.5 - 1.0  # (palette_size, 3)
            # Repeat/tile to fill d_model dimensions
            repeats = self.d_model // 3 + 1
            expanded = rgb_norm.repeat(1, repeats)[:, :self.d_model]
            # Mix in some noise for diversity
            self.palette_keys.data[:palette_size] = expanded + 0.1 * torch.randn_like(expanded)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch, seq_len, d_model)
        Returns:
            logits: (batch, seq_len, total_vocab)
        """
        queries = self.query_proj(hidden_states)  # (batch, seq_len, d_model)
        # Scaled dot-product attention with palette keys
        logits = torch.matmul(queries, self.palette_keys.T) / self.temperature
        return logits


class BitMultiheadAttention(nn.Module):
    """
    Multi-head attention with BitLinear158 projections.

    Q, K, V projections and the output projection all use 1.58-bit weights.
    The attention computation itself remains in full precision.

    Following BitNet b1.58: the RMSNorm that normally precedes attention
    is absorbed into the BitLinear158 layers (they have built-in RMSNorm).
    """

    def __init__(self, d_model: int, nhead: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % nhead == 0, f"d_model ({d_model}) must be divisible by nhead ({nhead})"

        self.d_model = d_model
        self.nhead = nhead
        self.head_dim = d_model // nhead

        # QKV projections — all 1.58-bit
        self.q_proj = BitLinear158(d_model, d_model)
        self.k_proj = BitLinear158(d_model, d_model)
        self.v_proj = BitLinear158(d_model, d_model)

        # Output projection — 1.58-bit
        self.out_proj = BitLinear158(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            query: (batch, q_len, d_model)
            key: (batch, kv_len, d_model)
            value: (batch, kv_len, d_model)
            attn_mask: (q_len, kv_len) or (batch*nhead, q_len, kv_len)
            key_padding_mask: (batch, kv_len)
        Returns:
            (batch, q_len, d_model)
        """
        batch_size = query.size(0)

        # Project Q, K, V through 1.58-bit linear layers
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        # Reshape for multi-head: (batch, seq, d_model) -> (batch, nhead, seq, head_dim)
        q = q.view(batch_size, -1, self.nhead, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.nhead, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.nhead, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale

        # Apply causal mask
        if attn_mask is not None:
            if attn_mask.dim() == 2:
                attn_weights = attn_weights + attn_mask.unsqueeze(0).unsqueeze(0)
            else:
                attn_weights = attn_weights + attn_mask

        # Apply padding mask
        if key_padding_mask is not None:
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf')
            )

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention to values
        attn_output = torch.matmul(attn_weights, v)

        # Reshape back: (batch, nhead, seq, head_dim) -> (batch, seq, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        # Output projection (1.58-bit)
        return self.out_proj(attn_output)


class BitPixelLMDecoderLayer(nn.Module):
    """
    Single decoder layer with 1.58-bit weights.

    Structure (per BitNet b1.58 / LLaMA convention):
    1. Self-attention with BitLinear158 projections (RMSNorm built into BitLinear)
    2. Cross-attention to text encoder output (BitLinear158 projections)
    3. SwiGLU feed-forward network (BitLinear158 projections)

    Pre-norm architecture, but the norm is absorbed into BitLinear158.
    Residual connections use a separate RMSNorm for gradient stability.
    """

    def __init__(self, d_model: int, nhead: int, dim_ff: int, dropout: float = 0.0):
        super().__init__()

        # Self-attention (masked, causal)
        self.self_attn = BitMultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm1 = RMSNorm(d_model)

        # Cross-attention to text
        self.cross_attn = BitMultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm2 = RMSNorm(d_model)

        # SwiGLU feed-forward (replaces GELU FFN)
        self.ff = SwiGLU(d_model, hidden_features=dim_ff, use_bitlinear=True)
        self.norm3 = RMSNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        text_enc: torch.Tensor,
        causal_mask: torch.Tensor,
        text_pad_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)
            text_enc: (batch, text_len, d_model)
            causal_mask: (seq_len, seq_len) causal attention mask
            text_pad_mask: (batch, text_len) padding mask for text
        Returns:
            (batch, seq_len, d_model)
        """
        # Pre-norm self-attention with residual
        residual = x
        x = self.norm1(x)
        x = self.self_attn(x, x, x, attn_mask=causal_mask)
        x = self.dropout(x) + residual

        # Pre-norm cross-attention with residual
        residual = x
        x = self.norm2(x)
        x = self.cross_attn(x, text_enc, text_enc, key_padding_mask=text_pad_mask)
        x = self.dropout(x) + residual

        # Pre-norm SwiGLU FFN with residual
        residual = x
        x = self.norm3(x)
        x = self.ff(x)
        x = self.dropout(x) + residual

        return x


class BitPixelLMDecoder(nn.Module):
    """
    1.58-bit PixelLM Decoder.

    Same architecture as PixelLMDecoder but with:
    - BitLinear158 replacing all nn.Linear in attention and FFN
    - RMSNorm replacing LayerNorm (absorbed into BitLinear + residual norms)
    - SwiGLU replacing GELU FFN
    - No biases

    Full precision components (NOT quantized):
    - Token embeddings (need full precision for gradient flow to embeddings)
    - 2D positional encoding (our unique spatial encoding)
    - Palette output head (needs high-precision logits for sampling)
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        nhead: int = 8,
        num_layers: int = 6,
        dim_feedforward: int = 512,
        img_size: int = 32,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.img_size = img_size
        self.max_seq_len = img_size * img_size + 2

        # ── Full precision components ─────────────────────────────
        # Token embedding (kept in FP32)
        self.token_embed = nn.Embedding(vocab_size, d_model)

        # 2D positional encoding (our unique contribution — kept FP32)
        self.pos_encoding = PixelPositionalEncoding2D(d_model, img_size)

        # Palette-aware output head (kept FP32 for sampling precision)
        self.output_head = PaletteOutputHead(d_model, vocab_size - 3, num_special_tokens=3)

        # ── 1.58-bit components ───────────────────────────────────
        # Decoder layers with BitLinear158
        self.layers = nn.ModuleList([
            BitPixelLMDecoderLayer(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_layers)
        ])

        # Final norm (full precision RMSNorm)
        self.final_norm = RMSNorm(d_model)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Cache for causal mask
        self._causal_mask_cache = {}

    def _get_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Generate or retrieve cached causal attention mask."""
        if seq_len not in self._causal_mask_cache:
            mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
            float_mask = torch.zeros(seq_len, seq_len, device=device)
            float_mask.masked_fill_(mask, float('-inf'))
            self._causal_mask_cache[seq_len] = float_mask
        return self._causal_mask_cache[seq_len]

    def forward(
        self,
        pixel_tokens: torch.Tensor,
        text_enc: torch.Tensor,
        text_pad_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for training (teacher-forced).

        Args:
            pixel_tokens: (batch, seq_len) long tensor of pixel token indices
            text_enc: (batch, text_len, d_model) text encoder output
            text_pad_mask: (batch, text_len) True where text is padded
        Returns:
            logits: (batch, seq_len, vocab_size)
        """
        batch_size, seq_len = pixel_tokens.shape
        device = pixel_tokens.device

        # Token embeddings (full precision)
        x = self.token_embed(pixel_tokens) * math.sqrt(self.d_model)

        # 2D positional encoding (full precision)
        pos = self.pos_encoding(seq_len, device)
        x = x + pos
        x = self.dropout(x)

        # Causal mask
        causal_mask = self._get_causal_mask(seq_len, device)

        # 1.58-bit decoder layers
        for layer in self.layers:
            x = layer(x, text_enc, causal_mask, text_pad_mask)

        # Final norm
        x = self.final_norm(x)

        # Output logits via palette-aware head (full precision)
        logits = self.output_head(x)

        return logits

    @torch.no_grad()
    def generate(
        self,
        text_enc: torch.Tensor,
        sos_token: int,
        eos_token: int,
        max_len: int = 1026,
        temperature: float = 0.8,
        top_k: int = 40,
        top_p: float = 0.9,
        text_pad_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Autoregressive generation (same interface as PixelLMDecoder).
        """
        device = text_enc.device
        tokens = torch.tensor([[sos_token]], dtype=torch.long, device=device)

        for _ in range(max_len - 1):
            logits = self.forward(tokens, text_enc, text_pad_mask)
            next_logits = logits[:, -1, :] / temperature

            # Top-k filtering
            if top_k > 0:
                topk_vals, _ = torch.topk(next_logits, top_k)
                next_logits[next_logits < topk_vals[:, -1:]] = float('-inf')

            # Top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                probs_sorted = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(probs_sorted, dim=-1)
                # Mask tokens whose cumulative probability (excluding themselves) reaches top_p
                sorted_mask = cumulative_probs - probs_sorted >= top_p
                sorted_logits[sorted_mask] = float('-inf')
                # Scatter the filtered logits back to their original positions
                next_logits = next_logits.scatter(1, sorted_indices, sorted_logits)

            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            tokens = torch.cat([tokens, next_token], dim=1)

            if next_token.item() == eos_token:
                break

        return tokens


class BitPixelLM(nn.Module):
    """
    Complete 1.58-bit PixelLM: Text Encoder (FP32) + Pixel Decoder (1.58-bit).

    The text encoder remains in full precision because:
    1. It's small (3 layers) — quantization overhead would negate the benefits
    2. Text understanding needs full precision for a small vocabulary

    The pixel decoder uses 1.58-bit weights for:
    1. All self-attention projections (Q, K, V, O)
    2. All cross-attention projections
    3. All FFN projections (SwiGLU)
    """

    def __init__(self, text_encoder: nn.Module, pixel_decoder: BitPixelLMDecoder):
        super().__init__()
        self.text_encoder = text_encoder
        self.pixel_decoder = pixel_decoder

    def forward(
        self,
        text_tokens: torch.Tensor,
        pixel_tokens: torch.Tensor,
    ) -> torch.Tensor:
        text_pad_mask = (text_tokens == 0)
        text_enc = self.text_encoder(text_tokens)
        logits = self.pixel_decoder(pixel_tokens, text_enc, text_pad_mask)
        return logits

    @torch.no_grad()
    def generate(
        self,
        text_tokens: torch.Tensor,
        sos_token: int,
        eos_token: int,
        **kwargs,
    ) -> torch.Tensor:
        text_pad_mask = (text_tokens == 0)
        text_enc = self.text_encoder(text_tokens)
        return self.pixel_decoder.generate(
            text_enc, sos_token, eos_token,
            text_pad_mask=text_pad_mask, **kwargs
        )

    def count_parameters(self) -> int:
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def count_bit_parameters(self) -> dict:
        """Count parameters by precision level."""
        bit_params = 0
        fp_params = 0
        for name, p in self.named_parameters():
            if not p.requires_grad:
                continue
            if 'pixel_decoder.layers' in name and '.weight' in name and 'norm' not in name and 'rms_norm' not in name:
                bit_params += p.numel()
            else:
                fp_params += p.numel()
        return {
            'ternary_params': bit_params,
            'fp32_params': fp_params,
            'total': bit_params + fp_params,
            'ternary_pct': bit_params / (bit_params + fp_params) * 100,
            'effective_bits': (bit_params * 1.58 + fp_params * 32) / (bit_params + fp_params),
        }


# ──── Testing ────────────────────────────────────────────────────

if __name__ == "__main__":
    from model.text_encoder import TextEncoder

    print("Building BitPixelLM...")

    # Build text encoder (full precision)
    text_encoder = TextEncoder(
        vocab_size=66,  # 62 words + 4 special
        d_model=256,
        nhead=4,
        num_layers=3,
        dim_feedforward=512,
        max_seq_len=32,
    )

    # Build 1.58-bit pixel decoder
    pixel_decoder = BitPixelLMDecoder(
        vocab_size=259,
        d_model=256,
        nhead=8,
        num_layers=6,
        dim_feedforward=512,
        img_size=32,
    )

    model = BitPixelLM(text_encoder, pixel_decoder)

    # Parameter count
    total = model.count_parameters()
    breakdown = model.count_bit_parameters()
    print(f"\nBitPixelLM: {total:,} total parameters")
    print(f"  Ternary (1.58-bit): {breakdown['ternary_params']:,} ({breakdown['ternary_pct']:.1f}%)")
    print(f"  Full precision: {breakdown['fp32_params']:,} ({100-breakdown['ternary_pct']:.1f}%)")
    print(f"  Effective bits/param: {breakdown['effective_bits']:.2f}")

    # Forward pass test
    text = torch.randint(0, 66, (2, 32))
    pixels = torch.randint(0, 259, (2, 1025))

    print("\nForward pass test...")
    logits = model(text, pixels)
    print(f"  Input: text={text.shape}, pixels={pixels.shape}")
    print(f"  Output: logits={logits.shape}")

    # Gradient test
    loss = logits[:, :, :259].sum()
    loss.backward()
    grad_ok = all(p.grad is not None for p in model.parameters() if p.requires_grad)
    print(f"  Gradient flow: {'OK' if grad_ok else 'FAILED'}")

    print("\nAll tests passed! ✓")
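The filtering logic in `generate` is easiest to see on a toy distribution. Below is a self-contained sketch applying the same top-k and top-p steps to hand-picked logits; the numbers are illustrative, not from the model:

```python
# Toy walk-through of the top-k / top-p filtering used in generate() above.
import torch
import torch.nn.functional as F

next_logits = torch.tensor([[2.0, 1.5, 1.0, 0.2, -1.0]])  # (batch=1, vocab=5), illustrative
temperature, top_k, top_p = 1.0, 4, 0.9
next_logits = next_logits / temperature

# Top-k: keep only the 4 highest logits (the -1.0 token is dropped)
topk_vals, _ = torch.topk(next_logits, top_k)
next_logits[next_logits < topk_vals[:, -1:]] = float('-inf')

# Top-p: drop tokens once the cumulative probability before them reaches 0.9
sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
probs_sorted = F.softmax(sorted_logits, dim=-1)
sorted_mask = torch.cumsum(probs_sorted, dim=-1) - probs_sorted >= top_p
sorted_logits[sorted_mask] = float('-inf')
next_logits = next_logits.scatter(1, sorted_indices, sorted_logits)

probs = F.softmax(next_logits, dim=-1)
print(probs)  # only the 3-token nucleus (~0.47, ~0.28, ~0.17 pre-filter) survives
next_token = torch.multinomial(probs, 1)
```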
model/bitlinear.py
ADDED
@@ -0,0 +1,239 @@

"""
PixelArtGen — BitLinear 1.58-bit Layer & RMSNorm

Implementation of the core BitNet b1.58 components:
- RMSNorm: Root Mean Square Layer Normalization (Zhang & Sennrich, 2019)
- BitLinear158: 1.58-bit linear layer with ternary weights {-1, 0, +1}

References:
- "The Era of 1-bit LLMs" (Ma et al., 2024) — arXiv:2402.17764
- "BitNet: Scaling 1-bit Transformers" (Wang et al., 2023) — arXiv:2310.11453
- "RMSNorm" (Zhang & Sennrich, 2019) — arXiv:1910.07467
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.

    Simpler and faster than LayerNorm — removes mean centering,
    keeps only the re-scaling by root mean square:

        RMSNorm(x) = x / RMS(x) * g,  where RMS(x) = sqrt(mean(x^2))

    Reference: arXiv:1910.07467
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def activation_quant(x: torch.Tensor) -> torch.Tensor:
    """
    Per-token 8-bit activation quantization from BitNet b1.58.

    Quantizes activations to [-127, 127] per-token using absmax scaling.
    Symmetric quantization (no zero-point) as specified in the paper.

    Args:
        x: (..., d_model) float tensor
    Returns:
        Fake-quantized tensor (still float, dequantized back to the input
        scale, with gradients flowing via the straight-through estimator)
    """
    Qb = 127  # 8-bit signed: 2^(8-1) - 1
    scale = x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    x_quant = (x * Qb / scale).clamp(-Qb, Qb).round()
    # STE: detach the rounding, keep gradients flowing
    x_quant = x + (x_quant * scale / Qb - x).detach()
    return x_quant


def weight_quant(w: torch.Tensor) -> tuple:
    """
    Absmean ternary weight quantization from BitNet b1.58.

    Quantizes weights to {-1, 0, +1} using absmean scaling:
    1. Compute gamma = mean(|W|)
    2. Scale: W_scaled = W / gamma
    3. Round to nearest in {-1, 0, +1}

    Args:
        w: (out_features, in_features) weight matrix
    Returns:
        (quantized_weights, scale_factor)
    """
    gamma = w.abs().mean().clamp(min=1e-5)
    w_scaled = w / gamma
    w_quant = w_scaled.clamp(-1, 1).round()
    # STE: detach the rounding, keep gradients on the latent weights
    w_quant = w + (w_quant * gamma - w).detach()
    return w_quant, gamma
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class BitLinear158(nn.Module):
|
| 89 |
+
"""
|
| 90 |
+
1.58-bit Linear Layer from BitNet b1.58.
|
| 91 |
+
|
| 92 |
+
Drop-in replacement for nn.Linear with:
|
| 93 |
+
- Ternary weights {-1, 0, +1} via absmean quantization
|
| 94 |
+
- 8-bit per-token activation quantization
|
| 95 |
+
- Built-in RMSNorm (absorbs the preceding LayerNorm)
|
| 96 |
+
- No bias (following BitNet b1.58 / LLaMA convention)
|
| 97 |
+
- Full-precision latent weights maintained for training (STE)
|
| 98 |
+
|
| 99 |
+
Forward pass:
|
| 100 |
+
1. RMSNorm the input
|
| 101 |
+
2. Quantize activations to 8-bit
|
| 102 |
+
3. Quantize weights to ternary
|
| 103 |
+
4. Matrix multiply (effectively integer addition)
|
| 104 |
+
5. Rescale output
|
| 105 |
+
|
| 106 |
+
During training, gradients flow through quantization via the
|
| 107 |
+
Straight-Through Estimator (STE) — the gradient of round()
|
| 108 |
+
is treated as the identity function.
|
| 109 |
+
|
| 110 |
+
Reference: arXiv:2402.17764
|
| 111 |
+
"""
|
| 112 |
+
|
| 113 |
+
def __init__(self, in_features: int, out_features: int):
|
| 114 |
+
super().__init__()
|
| 115 |
+
self.in_features = in_features
|
| 116 |
+
self.out_features = out_features
|
| 117 |
+
|
| 118 |
+
# Full-precision latent weight (master copy for training)
|
| 119 |
+
self.weight = nn.Parameter(torch.empty(out_features, in_features))
|
| 120 |
+
|
| 121 |
+
# Built-in RMSNorm (replaces the preceding LayerNorm)
|
| 122 |
+
self.rms_norm = RMSNorm(in_features)
|
| 123 |
+
|
| 124 |
+
# Initialize weights
|
| 125 |
+
self._init_weights()
|
| 126 |
+
|
| 127 |
+
def _init_weights(self):
|
| 128 |
+
"""Kaiming uniform initialization, same as nn.Linear."""
|
| 129 |
+
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
| 130 |
+
|
| 131 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 132 |
+
"""
|
| 133 |
+
Args:
|
| 134 |
+
x: (batch, seq_len, in_features)
|
| 135 |
+
Returns:
|
| 136 |
+
(batch, seq_len, out_features)
|
| 137 |
+
"""
|
| 138 |
+
# 1. Normalize input (built-in RMSNorm)
|
| 139 |
+
x = self.rms_norm(x)
|
| 140 |
+
|
| 141 |
+
# 2. Quantize activations to 8-bit per-token
|
| 142 |
+
x_q = activation_quant(x)
|
| 143 |
+
|
| 144 |
+
# 3. Quantize weights to ternary {-1, 0, +1}
|
| 145 |
+
w_q, w_scale = weight_quant(self.weight)
|
| 146 |
+
|
| 147 |
+
# 4. Matrix multiply with quantized weights and activations
|
| 148 |
+
# In theory this is integer addition; in practice we use float
|
| 149 |
+
# for autograd compatibility during training
|
| 150 |
+
output = F.linear(x_q, w_q)
|
| 151 |
+
|
| 152 |
+
return output
|
| 153 |
+
|
| 154 |
+
def extra_repr(self) -> str:
|
| 155 |
+
return f"in={self.in_features}, out={self.out_features}, bits=1.58"
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class SwiGLU(nn.Module):
|
| 159 |
+
"""
|
| 160 |
+
SwiGLU activation for Feed-Forward Networks.
|
| 161 |
+
|
| 162 |
+
SwiGLU(x) = (Swish(xW1) ⊙ xV) W2
|
| 163 |
+
|
| 164 |
+
Uses 3 linear projections instead of 2, but the hidden dim
|
| 165 |
+
is typically reduced by 2/3 to keep parameter count similar.
|
| 166 |
+
|
| 167 |
+
When used with BitLinear158, all three projections are ternary.
|
| 168 |
+
|
| 169 |
+
Reference: arXiv:2002.05202 (Shazeer, 2020)
|
| 170 |
+
"""
|
| 171 |
+
|
| 172 |
+
def __init__(self, in_features: int, hidden_features: int = None, use_bitlinear: bool = True):
|
| 173 |
+
super().__init__()
|
| 174 |
+
hidden_features = hidden_features or int(in_features * 8 / 3) # 2/3 of 4x expansion
|
| 175 |
+
# Round to nearest multiple of 8 for efficiency
|
| 176 |
+
hidden_features = ((hidden_features + 7) // 8) * 8
|
| 177 |
+
|
| 178 |
+
Linear = BitLinear158 if use_bitlinear else nn.Linear
|
| 179 |
+
|
| 180 |
+
self.w1 = Linear(in_features, hidden_features) # gate projection
|
| 181 |
+
self.v = Linear(in_features, hidden_features) # value projection
|
| 182 |
+
self.w2 = Linear(hidden_features, in_features) # output projection
|
| 183 |
+
|
| 184 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 185 |
+
return self.w2(F.silu(self.w1(x)) * self.v(x))
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
# ──── Testing ────────────────────────────────────────────────────
|
| 189 |
+
|
| 190 |
+
if __name__ == "__main__":
|
| 191 |
+
print("Testing BitLinear158 components...")
|
| 192 |
+
|
| 193 |
+
# Test RMSNorm
|
| 194 |
+
norm = RMSNorm(256)
|
| 195 |
+
x = torch.randn(2, 10, 256)
|
| 196 |
+
y = norm(x)
|
| 197 |
+
print(f"RMSNorm: {x.shape} -> {y.shape}, mean={y.mean():.4f}, std={y.std():.4f}")
|
| 198 |
+
|
| 199 |
+
# Test weight quantization
|
| 200 |
+
w = torch.randn(512, 256)
|
| 201 |
+
w_q, scale = weight_quant(w)
|
| 202 |
+
unique = torch.unique(w_q.detach())
|
| 203 |
+
print(f"Weight quant: {w.shape}, unique values: {len(unique)}, scale: {scale:.4f}")
|
| 204 |
+
print(f" Ternary distribution: -1={((w_q.detach().round() == -1).sum().item())}, "
|
| 205 |
+
f"0={((w_q.detach().round() == 0).sum().item())}, "
|
| 206 |
+
f"+1={((w_q.detach().round() == 1).sum().item())}")
|
| 207 |
+
|
| 208 |
+
# Test activation quantization
|
| 209 |
+
a = torch.randn(2, 10, 256)
|
| 210 |
+
a_q = activation_quant(a)
|
| 211 |
+
print(f"Activation quant: range [{a_q.min():.2f}, {a_q.max():.2f}]")
|
| 212 |
+
|
| 213 |
+
# Test BitLinear158
|
| 214 |
+
layer = BitLinear158(256, 512)
|
| 215 |
+
x = torch.randn(2, 10, 256)
|
| 216 |
+
y = layer(x)
|
| 217 |
+
print(f"BitLinear158: {x.shape} -> {y.shape}")
|
| 218 |
+
|
| 219 |
+
# Test gradient flow (STE)
|
| 220 |
+
loss = y.sum()
|
| 221 |
+
loss.backward()
|
| 222 |
+
assert layer.weight.grad is not None, "Gradient did not flow through STE!"
|
| 223 |
+
print(f"STE gradient flow: OK (grad norm: {layer.weight.grad.norm():.4f})")
|
| 224 |
+
|
| 225 |
+
# Test SwiGLU
|
| 226 |
+
swiglu = SwiGLU(256, use_bitlinear=True)
|
| 227 |
+
x = torch.randn(2, 10, 256)
|
| 228 |
+
y = swiglu(x)
|
| 229 |
+
print(f"SwiGLU (BitLinear): {x.shape} -> {y.shape}")
|
| 230 |
+
total = sum(p.numel() for p in swiglu.parameters())
|
| 231 |
+
print(f" SwiGLU params: {total:,}")
|
| 232 |
+
|
| 233 |
+
# Parameter comparison
|
| 234 |
+
ff_standard = nn.Sequential(nn.Linear(256, 512), nn.GELU(), nn.Linear(512, 256))
|
| 235 |
+
ff_params = sum(p.numel() for p in ff_standard.parameters())
|
| 236 |
+
print(f" Standard FFN params: {ff_params:,}")
|
| 237 |
+
print(f" Ratio: {total / ff_params:.2f}x")
|
| 238 |
+
|
| 239 |
+
print("\nAll tests passed! ✓")
|
model/text_encoder.py
ADDED
|
@@ -0,0 +1,122 @@
"""
PixelArtGen — Text Encoder

A small transformer encoder that converts text prompts into
contextual embeddings for conditioning the pixel art decoder.
"""

import math
from typing import List

import torch
import torch.nn as nn


class TextTokenizer:
    """Simple word-level tokenizer for text prompts."""

    def __init__(self, vocab: List[str]):
        self.word2idx = {w: i for i, w in enumerate(vocab)}
        self.idx2word = {i: w for i, w in enumerate(vocab)}
        self.pad_idx = self.word2idx.get("<pad>", 0)
        self.sos_idx = self.word2idx.get("<sos>", 1)
        self.eos_idx = self.word2idx.get("<eos>", 2)
        self.unk_idx = self.word2idx.get("<unk>", 3)
        self.vocab_size = len(vocab)

    def encode(self, text: str, max_len: int = 32) -> torch.Tensor:
        """Tokenize and pad a text prompt."""
        words = text.lower().strip().split()
        tokens = [self.sos_idx]
        for w in words:
            tokens.append(self.word2idx.get(w, self.unk_idx))
        tokens.append(self.eos_idx)

        # Pad or truncate
        if len(tokens) > max_len:
            tokens = tokens[:max_len]
        else:
            tokens += [self.pad_idx] * (max_len - len(tokens))

        return torch.tensor(tokens, dtype=torch.long)

    def encode_batch(self, texts: List[str], max_len: int = 32) -> torch.Tensor:
        """Encode a batch of text prompts."""
        return torch.stack([self.encode(t, max_len) for t in texts])


class TextEncoder(nn.Module):
    """
    Small transformer encoder for text prompts.

    Architecture:
    - Word embeddings + sinusoidal positional encoding
    - N transformer encoder layers with multi-head attention
    - Output: sequence of contextual embeddings (batch, seq_len, d_model)
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        nhead: int = 4,
        num_layers: int = 3,
        dim_feedforward: int = 512,
        max_seq_len: int = 32,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len)
        self.dropout = nn.Dropout(dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, text_tokens: torch.Tensor) -> torch.Tensor:
        """
        Args:
            text_tokens: (batch, seq_len) long tensor of word indices
        Returns:
            (batch, seq_len, d_model) contextual embeddings
        """
        # Create padding mask (True = ignore)
        pad_mask = (text_tokens == 0)  # pad_idx = 0

        # Embed + positional encode
        x = self.embedding(text_tokens) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        x = self.dropout(x)

        # Transformer encode
        x = self.transformer(x, src_key_padding_mask=pad_mask)
        x = self.norm(x)

        return x


class SinusoidalPositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding."""

    def __init__(self, d_model: int, max_len: int = 512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pe[:, :x.size(1)]
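A quick usage sketch of the tokenizer and encoder above. The toy vocabulary is an illustrative stand-in for the repo's `vocab.json` (222 words); with the default `d_model=256`, the encoder emits one 256-dim embedding per token position.

```python
import torch
from model.text_encoder import TextTokenizer, TextEncoder

# Toy vocabulary for illustration; the real one is loaded from vocab.json.
vocab = ["<pad>", "<sos>", "<eos>", "<unk>", "a", "red", "pixel", "art", "sword"]
tok = TextTokenizer(vocab)
enc = TextEncoder(vocab_size=tok.vocab_size)

tokens = tok.encode_batch(["a red pixel art sword"])  # (1, 32) long tensor
with torch.no_grad():
    emb = enc(tokens)                                  # (1, 32, 256)
print(emb.shape)
```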
|
model/tokenizer.py
ADDED
|
@@ -0,0 +1,106 @@
"""
PixelArtGen — Color Palette Tokenizer

Converts 32×32 RGB pixel art images into sequences of palette indices
and back. This is the "vocabulary" for the pixel language model.
"""

import numpy as np
import torch
from pathlib import Path


class PaletteTokenizer:
    """
    Maps RGB pixels to/from a fixed palette of N colors.
    Each pixel becomes a token index ∈ [0, palette_size).

    Special tokens:
        palette_size     = <sos> (start of sequence)
        palette_size + 1 = <eos> (end of sequence)
        palette_size + 2 = <pad> (padding)
    """

    def __init__(self, palette_path: str = None, palette: np.ndarray = None, palette_size: int = 256):
        if palette is not None:
            self.palette = palette.astype(np.float32)
        elif palette_path is not None:
            self.palette = np.load(palette_path).astype(np.float32)
        else:
            raise ValueError("Must provide palette_path or palette array")

        self.palette_size = len(self.palette)
        self.sos_token = self.palette_size
        self.eos_token = self.palette_size + 1
        self.pad_token = self.palette_size + 2
        self.vocab_size = self.palette_size + 3  # colors + sos + eos + pad

    def rgb_to_index(self, rgb: np.ndarray) -> int:
        """Find the closest palette color for an RGB value."""
        distances = np.sum((self.palette - rgb.astype(np.float32)) ** 2, axis=1)
        return int(np.argmin(distances))

    def encode_image(self, img_array: np.ndarray) -> list:
        """
        Encode a 32×32×3 RGB image into a flat sequence of palette indices.
        Returns: [sos, p0, p1, ..., p1023, eos] (1026 tokens)
        """
        h, w, c = img_array.shape
        assert h == 32 and w == 32 and c == 3, f"Expected 32×32×3, got {img_array.shape}"

        tokens = [self.sos_token]
        for y in range(h):
            for x in range(w):
                pixel = img_array[y, x]
                idx = self.rgb_to_index(pixel)
                tokens.append(idx)
        tokens.append(self.eos_token)
        return tokens

    def encode_image_fast(self, img_array: np.ndarray) -> list:
        """Vectorized encoding — much faster than pixel-by-pixel."""
        h, w, c = img_array.shape
        pixels = img_array.reshape(-1, 3).astype(np.float32)  # (1024, 3)

        # Compute distances to all palette colors at once
        # pixels: (1024, 3), palette: (N, 3)
        diff = pixels[:, None, :] - self.palette[None, :, :]  # (1024, N, 3)
        distances = np.sum(diff ** 2, axis=2)                 # (1024, N)
        indices = np.argmin(distances, axis=1)                # (1024,)

        tokens = [self.sos_token] + indices.tolist() + [self.eos_token]
        return tokens

    def decode_tokens(self, tokens: list) -> np.ndarray:
        """
        Decode a sequence of palette indices back to a 32×32×3 RGB image.
        Strips sos/eos/pad tokens.
        """
        # Filter special tokens
        pixel_tokens = [t for t in tokens if t < self.palette_size]

        # Pad or truncate to exactly 1024 pixels
        if len(pixel_tokens) < 1024:
            pixel_tokens += [0] * (1024 - len(pixel_tokens))
        pixel_tokens = pixel_tokens[:1024]

        img = np.zeros((1024, 3), dtype=np.uint8)
        for i, idx in enumerate(pixel_tokens):
            idx = min(idx, self.palette_size - 1)
            img[i] = self.palette[idx].astype(np.uint8)

        return img.reshape(32, 32, 3)

    def tokens_to_tensor(self, tokens: list, max_len: int = 1026) -> torch.Tensor:
        """Convert token list to padded tensor."""
        if len(tokens) > max_len:
            tokens = tokens[:max_len]
        else:
            tokens = tokens + [self.pad_token] * (max_len - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)

    def get_palette_tensor(self) -> torch.Tensor:
        """Return the palette as a (palette_size, 3) float32 tensor."""
        return torch.tensor(self.palette, dtype=torch.float32)
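A round-trip sketch of the palette tokenizer. The random palette here stands in for the repo's `palette_256.npy`; decoding re-quantizes each pixel to its nearest palette color, so `recon` matches `img` only up to palette error.

```python
import numpy as np
from model.tokenizer import PaletteTokenizer

# Random stand-in palette; in the repo this comes from palette_256.npy.
palette = np.random.randint(0, 256, size=(256, 3)).astype(np.uint8)
tok = PaletteTokenizer(palette=palette)

img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
tokens = tok.encode_image_fast(img)   # [sos, 1024 palette indices, eos]
assert len(tokens) == 1026
recon = tok.decode_tokens(tokens)     # (32, 32, 3), palette-quantized
print(recon.shape, tok.vocab_size)    # (32, 32, 3) 259
```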
|