BlakePeavy committed
Commit 72e872c · verified · 1 Parent(s): 58a5e05

Upload BitPixelLM model artifacts

Files changed (10)
  1. README.md +76 -0
  2. app.py +310 -0
  3. best.pt +3 -0
  4. config.json +7 -0
  5. generate.py +196 -0
  6. model/__init__.py +17 -0
  7. model/bit_pixel_decoder.py +577 -0
  8. model/bitlinear.py +239 -0
  9. model/text_encoder.py +122 -0
  10. model/tokenizer.py +106 -0
README.md ADDED
@@ -0,0 +1,76 @@
# BitPixelLM

BitPixelLM is a text-to-pixel-art language model that generates 32x32 images from prompts like `a red pixel art sword`.

It uses a BitNet b1.58-style ternary decoder (`-1, 0, +1`) with a lightweight text encoder.

## Current Model Snapshot

- Model name: **BitPixelLM**
- Architecture: 3-layer text encoder + 6-layer BitPixelLM decoder
- Parameters: ~7.3M
- Dataset (v3): 23,648 synthetic pixel-art samples
- Vocab: 222 words
- Best validation loss (v3): ~0.4015

## Project Layout

- `model/bit_pixel_decoder.py` — BitPixelLM model
- `train_bitnet.py` — training pipeline
- `generate.py` — CLI generation
- `app.py` — Gradio app
- `data/generate_v3.py` — v3 dataset generator
- `PixelArtGen_Colab.ipynb` — Colab training notebook

## Run Locally

1. Ensure Python 3.9+ and a CUDA-enabled PyTorch install.
2. Place data in `D:\PixelArtGen_Data\processed`:
   - `tokens.npy`, `labels.json`, `vocab.json`, `palette_256.npy`
3. Train:

   ```bash
   python train_bitnet.py --epochs 60 --batch-size 32 --lr 5e-4
   ```

4. Launch the app:

   ```bash
   python app.py
   ```

## Publish to Hugging Face

This repo includes `publish_hf.py` for one-step upload.

### Required

- Hugging Face token with write access (`HF_TOKEN`)
- `huggingface_hub` installed

### Command

```bash
pip install huggingface_hub
python publish_hf.py --repo-id YOUR_USERNAME/BitPixelLM --token $HF_TOKEN
```

On Windows PowerShell:

```powershell
$env:HF_TOKEN = "hf_xxx"
python publish_hf.py --repo-id YOUR_USERNAME/BitPixelLM --token $env:HF_TOKEN
```

This uploads:

- `checkpoints_bit/best.pt`
- `model/` Python files
- `generate.py`
- `app.py`
- `README.md` (model card / usage overview)

## Notes

- The active production model is **BitPixelLM**.
- Legacy FP32 `PixelLM` artifacts remain in the repo only for historical reference.
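For quick inference from the published repo, a minimal sketch (the repo id is a placeholder, and the processed-data directory must exist locally as described above; `snapshot_download` is the stock `huggingface_hub` call):

```python
# Minimal inference sketch — repo id is a placeholder, and the processed data
# (vocab.json, palette_256.npy) must be available at --data-dir locally.
from huggingface_hub import snapshot_download
import subprocess

local_dir = snapshot_download(repo_id="YOUR_USERNAME/BitPixelLM")

subprocess.run([
    "python", f"{local_dir}/generate.py",
    "--checkpoint", f"{local_dir}/best.pt",
    "--data-dir", r"D:\PixelArtGen_Data\processed",
    "--prompt", "a red pixel art sword",
    "--output", "sword.png",
], check=True)
```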
app.py ADDED
@@ -0,0 +1,310 @@
"""
PixelArtGen — Gradio Web UI

Interactive UI to generate pixel art from text prompts using
BitPixelLM — a 1.58-bit ternary transformer (BitNet b1.58).

Launch:
    python app.py
Then open http://localhost:7860 in your browser.
"""

import sys
import json
import torch
import numpy as np
import gradio as gr
from pathlib import Path
from PIL import Image

sys.path.insert(0, str(Path(__file__).parent))

from model.tokenizer import PaletteTokenizer
from model.text_encoder import TextTokenizer, TextEncoder
from model.bit_pixel_decoder import BitPixelLMDecoder, BitPixelLM

# ─── Config ──────────────────────────────────────────────────────
DATA_DIR = Path(r"D:\PixelArtGen_Data\processed")
CHECKPOINT_PATH = Path("checkpoints_bit/best.pt")

# ─── Global state (loaded once) ─────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = None
palette_tok = None
text_tok = None


def load_tokenizers():
    """Load shared tokenizers."""
    global palette_tok, text_tok
    palette_tok = PaletteTokenizer(palette_path=str(DATA_DIR / "palette_256.npy"))
    with open(DATA_DIR / "vocab.json") as f:
        vocab = json.load(f)
    text_tok = TextTokenizer(vocab)


def load_model():
    """Load the BitPixelLM model from checkpoint."""
    global model
    if model is not None:
        return model

    if not CHECKPOINT_PATH.exists():
        raise FileNotFoundError(
            f"Checkpoint not found: {CHECKPOINT_PATH}\n"
            "BitPixelLM is still training — check back once training completes."
        )

    checkpoint = torch.load(str(CHECKPOINT_PATH), map_location=device, weights_only=False)
    model_args = checkpoint.get("args", {})

    d_model = model_args.get("d_model", 256)
    nhead = model_args.get("nhead", 8)
    text_layers = model_args.get("text_layers", 3)
    pixel_layers = model_args.get("pixel_layers", 6)
    dim_ff = model_args.get("dim_ff", 512)
    dropout = model_args.get("dropout", 0.1)
    max_text_len = model_args.get("max_text_len", 32)

    text_encoder = TextEncoder(
        vocab_size=text_tok.vocab_size,
        d_model=d_model,
        nhead=nhead,
        num_layers=text_layers,
        dim_feedforward=dim_ff,
        max_seq_len=max_text_len,
        dropout=dropout,
    )

    pixel_decoder = BitPixelLMDecoder(
        vocab_size=palette_tok.vocab_size,
        d_model=d_model,
        nhead=nhead,
        num_layers=pixel_layers,
        dim_feedforward=dim_ff,
        img_size=32,
        dropout=dropout,
    )
    m = BitPixelLM(text_encoder, pixel_decoder).to(device)

    m.load_state_dict(checkpoint["model_state_dict"])
    m.eval()
    model = m
    return model


def generate(
    prompt: str,
    temperature: float,
    top_k: int,
    top_p: float,
    num_samples: int,
    scale: int,
):
    """Generate pixel art from a text prompt."""
    if not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    if model is None:
        raise gr.Error(
            "BitPixelLM is not loaded yet. "
            "It may still be training — check back once training completes."
        )

    text_tokens = text_tok.encode(prompt).unsqueeze(0).to(device)

    # Warn about unknown words (still generates, but quality may suffer)
    words = prompt.lower().strip().split()
    unknown = [w for w in words if w not in text_tok.word2idx and w not in ("<pad>", "<sos>", "<eos>", "<unk>")]

    images = []
    try:
        for _ in range(int(num_samples)):
            with torch.no_grad():
                generated_tokens = model.generate(
                    text_tokens,
                    sos_token=palette_tok.sos_token,
                    eos_token=palette_tok.eos_token,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                )

            token_list = generated_tokens[0].cpu().tolist()
            img_array = palette_tok.decode_tokens(token_list)
            img = Image.fromarray(img_array, "RGB")

            # Upscale with nearest-neighbor for crisp pixels
            s = int(scale)
            if s > 1:
                img = img.resize((32 * s, 32 * s), Image.NEAREST)

            images.append(img)
    except Exception as e:
        raise gr.Error(f"Generation failed: {e}")

    if unknown:
        gr.Warning(
            f"Unknown words treated as <unk>: {', '.join(unknown)}. "
            f"Try using words from the vocabulary list below."
        )

    return images

# ─── Build UI ─────────────────────────────────────────────────────

# Load vocabulary dynamically from processed data
def _load_vocab_words():
    try:
        with open(DATA_DIR / "vocab.json") as f:
            vocab = json.load(f)
        return sorted([w for w in vocab if not w.startswith("<")])
    except Exception:
        return ["pixel", "art", "sword", "red", "blue", "green"]

VOCAB_WORDS = _load_vocab_words()

EXAMPLE_PROMPTS = [
    "a red pixel art sword",
    "a green pixel art dragon",
    "a purple pixel art crystal",
    "a blue pixel art knight",
    "a gold pixel art castle",
    "a red pixel art phoenix",
    "a dark pixel art skeleton",
    "a teal pixel art wizard",
    "a silver pixel art robot",
    "an orange pixel art fox",
]


def build_ui():
    with gr.Blocks(
        title="PixelArtGen",
        theme=gr.themes.Soft(primary_hue="purple"),
        css="""
        .gallery-item img { image-rendering: pixelated !important; }
        .output-gallery img { image-rendering: pixelated !important; }
        #gallery img { image-rendering: pixelated !important; }
        """,
    ) as app:
        gr.Markdown(
            """
            # PixelArtGen
            ### Generate 32x32 pixel art from text prompts

            Powered by **BitPixelLM** — a custom 1.58-bit ternary transformer built from scratch
            using BitNet b1.58 with RMSNorm, SwiGLU, and 2D positional encoding.
            7.3M parameters (75% ternary weights at 1.58 bits per weight).
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="a red pixel art sword",
                    lines=2,
                )
                with gr.Row():
                    generate_btn = gr.Button("Generate", variant="primary", scale=2)
                    num_samples = gr.Slider(1, 8, value=4, step=1, label="Samples")

                with gr.Accordion("Advanced Settings", open=False):
                    temperature = gr.Slider(
                        0.1, 2.0, value=0.8, step=0.05,
                        label="Temperature",
                        info="Lower = more deterministic, higher = more creative"
                    )
                    top_k = gr.Slider(
                        0, 256, value=40, step=1,
                        label="Top-K",
                        info="0 = disabled. Limits sampling to top K tokens."
                    )
                    top_p = gr.Slider(
                        0.1, 1.0, value=0.9, step=0.05,
                        label="Top-P (Nucleus)",
                        info="Cumulative probability threshold for sampling."
                    )
                    scale = gr.Slider(
                        1, 16, value=8, step=1,
                        label="Upscale Factor",
                        info="8x = 256x256, 16x = 512x512"
                    )

                gr.Markdown(
                    f"**Known vocabulary:** {', '.join(VOCAB_WORDS)}"
                )

            with gr.Column(scale=2):
                gallery = gr.Gallery(
                    label="Generated Pixel Art",
                    columns=4,
                    rows=2,
                    height=520,
                    object_fit="contain",
                    elem_id="gallery",
                )

                gr.Markdown("### Examples")
                gr.Examples(
                    examples=EXAMPLE_PROMPTS,
                    inputs=[prompt],
                    label="Click to try",
                )

        gr.Markdown(
            """
            ---
            **Architecture:**
            BitPixelLM treats pixel art generation as language modeling — each pixel is a token from a 256-color palette,
            generated left-to-right, top-to-bottom via a causal transformer with 2D positional encoding and cross-attention to text.
            Uses 1.58-bit ternary weights (BitNet b1.58) with RMSNorm and SwiGLU for extreme parameter efficiency.
            """
        )

        # Wire up the generate button
        generate_btn.click(
            fn=generate,
            inputs=[prompt, temperature, top_k, top_p, num_samples, scale],
            outputs=gallery,
        )

        # Also generate on Enter
        prompt.submit(
            fn=generate,
            inputs=[prompt, temperature, top_k, top_p, num_samples, scale],
            outputs=gallery,
        )

    return app


# ─── Main ────────────────────────────────────────────────────────
if __name__ == "__main__":
    print("Loading tokenizers...")
    load_tokenizers()
    print(f"  Palette: {palette_tok.vocab_size} tokens")
    print(f"  Text: {text_tok.vocab_size} words")
    print(f"  Device: {device}")

    # Load BitPixelLM
    print(f"Loading BitPixelLM from {CHECKPOINT_PATH}...")
    try:
        load_model()
        print("  BitPixelLM loaded successfully.")
    except FileNotFoundError as e:
        print(f"  {e}")
        print("  UI will launch but generation will be unavailable until training completes.")
    except Exception as e:
        print(f"  Failed to load BitPixelLM: {e}")

    print("\nLaunching UI...")
    app = build_ui()
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        inbrowser=True,
    )
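A running app can also be driven programmatically. A sketch with `gradio_client` — the `api_name` below assumes Gradio's default naming for the `generate` click handler, so verify it against the app's "Use via API" page:

```python
# Hypothetical programmatic call to the running app — the endpoint name
# ("/generate") is assumed from Gradio's default fn-name-based naming.
from gradio_client import Client

client = Client("http://localhost:7860")
result = client.predict(
    "a red pixel art sword",  # prompt
    0.8,                      # temperature
    40,                       # top_k
    0.9,                      # top_p
    4,                        # num_samples
    8,                        # scale
    api_name="/generate",
)
print(result)  # paths to the generated gallery images
```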
best.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:37ceef8a7d844445be4bc5730bcd683d1512aff084a6e872634b3184c58f2464
size 88732053
config.json ADDED
@@ -0,0 +1,7 @@
{
  "model_type": "BitPixelLM",
  "architecture": "BitNet-b1.58-style autoregressive decoder",
  "image_size": 32,
  "task": "text-to-image (pixel art)",
  "checkpoint_file": "best.pt"
}
generate.py ADDED
@@ -0,0 +1,196 @@
"""
PixelArtGen — Generate pixel art from text prompts.

Usage:
    python generate.py --prompt "a red pixel art sword" --output output.png
    python generate.py --prompt "a blue pixel art heart" --output heart.png --temperature 0.7
    python generate.py --batch-prompts prompts.txt --output-dir outputs/
"""

import os
import sys
import json
import argparse
import numpy as np
import torch
from pathlib import Path
from PIL import Image

sys.path.insert(0, str(Path(__file__).parent))

from model.tokenizer import PaletteTokenizer
from model.text_encoder import TextTokenizer, TextEncoder
# Use the BitNet decoder shipped in this upload (the legacy
# model/pixel_decoder.py is not included in these artifacts)
from model.bit_pixel_decoder import BitPixelLMDecoder, BitPixelLM


def load_model(checkpoint_path: str, data_dir: str, device: torch.device):
    """Load a trained BitPixelLM model from checkpoint."""
    data_dir = Path(data_dir)

    # Load checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model_args = checkpoint.get("args", {})

    # Load tokenizers
    palette_tok = PaletteTokenizer(palette_path=str(data_dir / "palette_256.npy"))

    with open(data_dir / "vocab.json") as f:
        vocab = json.load(f)
    text_tok = TextTokenizer(vocab)

    # Rebuild model
    d_model = model_args.get("d_model", 256)
    nhead = model_args.get("nhead", 8)
    text_layers = model_args.get("text_layers", 3)
    pixel_layers = model_args.get("pixel_layers", 6)
    dim_ff = model_args.get("dim_ff", 512)
    dropout = model_args.get("dropout", 0.1)
    max_text_len = model_args.get("max_text_len", 32)

    text_encoder = TextEncoder(
        vocab_size=text_tok.vocab_size,
        d_model=d_model,
        nhead=nhead,
        num_layers=text_layers,
        dim_feedforward=dim_ff,
        max_seq_len=max_text_len,
        dropout=dropout,
    )

    pixel_decoder = BitPixelLMDecoder(
        vocab_size=palette_tok.vocab_size,
        d_model=d_model,
        nhead=nhead,
        num_layers=pixel_layers,
        dim_feedforward=dim_ff,
        img_size=32,
        dropout=dropout,
    )

    model = BitPixelLM(text_encoder, pixel_decoder).to(device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    return model, palette_tok, text_tok


def generate_pixel_art(
    model: BitPixelLM,
    palette_tok: PaletteTokenizer,
    text_tok: TextTokenizer,
    prompt: str,
    device: torch.device,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 0.9,
    scale: int = 8,
) -> Image.Image:
    """
    Generate a 32×32 pixel art image from a text prompt.

    Args:
        model: Trained BitPixelLM model
        palette_tok: Color palette tokenizer
        text_tok: Text tokenizer
        prompt: Text description
        device: torch device
        temperature: Sampling temperature (lower = more deterministic)
        top_k: Top-k filtering
        top_p: Nucleus sampling threshold
        scale: Upscale factor for display (8 = 256×256 output)
    Returns:
        PIL Image (32*scale × 32*scale)
    """
    # Tokenize prompt
    text_tokens = text_tok.encode(prompt).unsqueeze(0).to(device)

    # Generate
    with torch.no_grad():
        generated_tokens = model.generate(
            text_tokens,
            sos_token=palette_tok.sos_token,
            eos_token=palette_tok.eos_token,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

    # Decode to image
    token_list = generated_tokens[0].cpu().tolist()
    img_array = palette_tok.decode_tokens(token_list)
    img = Image.fromarray(img_array, "RGB")

    # Upscale with nearest-neighbor (pixel art style)
    if scale > 1:
        img = img.resize((32 * scale, 32 * scale), Image.NEAREST)

    return img


def main():
    parser = argparse.ArgumentParser(description="Generate pixel art from text")
    parser.add_argument("--prompt", type=str, help="Text prompt")
    parser.add_argument("--output", type=str, default="output.png", help="Output file")
    # Default matches the checkpoint directory used by app.py
    parser.add_argument("--checkpoint", type=str, default="checkpoints_bit/best.pt")
    parser.add_argument("--data-dir", type=str, default=r"D:\PixelArtGen_Data\processed")
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=40)
    parser.add_argument("--top-p", type=float, default=0.9)
    parser.add_argument("--scale", type=int, default=8, help="Upscale factor")
    parser.add_argument("--num-samples", type=int, default=1, help="Number of images to generate")
    parser.add_argument("--batch-prompts", type=str, help="File with prompts (one per line)")
    parser.add_argument("--output-dir", type=str, default="outputs")

    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Load model
    print(f"Loading model from {args.checkpoint}...")
    model, palette_tok, text_tok = load_model(args.checkpoint, args.data_dir, device)
    print(f"  Model: {model.count_parameters():,} parameters")

    # Collect prompts
    if args.batch_prompts:
        with open(args.batch_prompts) as f:
            prompts = [line.strip() for line in f if line.strip()]
    elif args.prompt:
        prompts = [args.prompt]
    else:
        prompts = [
            "a red pixel art sword",
            "a blue pixel art heart",
            "a green pixel art tree",
            "a purple pixel art gem",
        ]

    # Generate
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for i, prompt in enumerate(prompts):
        print(f"\nGenerating: \"{prompt}\"")
        for j in range(args.num_samples):
            img = generate_pixel_art(
                model, palette_tok, text_tok, prompt, device,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
                scale=args.scale,
            )

            if len(prompts) == 1 and args.num_samples == 1:
                out_path = args.output
            else:
                safe_name = prompt.replace(" ", "_")[:30]
                out_path = output_dir / f"{safe_name}_{j}.png"

            img.save(str(out_path))
            print(f"  Saved: {out_path}")

    print("\nDone!")


if __name__ == "__main__":
    main()
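For the `--batch-prompts` mode, the prompts file is one prompt per line; for example:

```bash
printf 'a red pixel art sword\na blue pixel art heart\n' > prompts.txt
python generate.py --batch-prompts prompts.txt --num-samples 2 --output-dir outputs/
```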
model/__init__.py ADDED
@@ -0,0 +1,17 @@
"""PixelArtGen model package."""

from .tokenizer import PaletteTokenizer
from .text_encoder import TextTokenizer, TextEncoder
from .bitlinear import BitLinear158, RMSNorm, SwiGLU
from .bit_pixel_decoder import BitPixelLMDecoder, BitPixelLM

__all__ = [
    "PaletteTokenizer",
    "TextTokenizer",
    "TextEncoder",
    "BitLinear158",
    "RMSNorm",
    "SwiGLU",
    "BitPixelLMDecoder",
    "BitPixelLM",
]
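With the repo root on `sys.path` (as `app.py` and `generate.py` arrange via `sys.path.insert`), the whole public surface imports in one line:

```python
from model import (
    BitPixelLM, BitPixelLMDecoder, BitLinear158, RMSNorm, SwiGLU,
    PaletteTokenizer, TextEncoder, TextTokenizer,
)
```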
model/bit_pixel_decoder.py ADDED
@@ -0,0 +1,577 @@
"""
PixelArtGen — BitPixelLM Decoder (1.58-bit)

A ternary-weight variant of our PixelLM decoder, implementing BitNet b1.58.
Replaces nn.Linear layers with BitLinear158 (ternary weights {-1, 0, +1})
and uses modern LLaMA-like components (RMSNorm, SwiGLU, no biases).

Key differences from the standard PixelLM decoder:
- BitLinear158 layers with built-in RMSNorm (replaces nn.Linear + LayerNorm)
- SwiGLU FFN activation (replaces GELU)
- No biases anywhere
- Token embeddings and output head remain in full precision
- 2D positional encoding preserved (our unique contribution)

References:
- "The Era of 1-bit LLMs" (Ma et al., 2024) — arXiv:2402.17764
- "BitNet" (Wang et al., 2023) — arXiv:2310.11453
- "GLU Variants Improve Transformer" (Shazeer, 2020) — arXiv:2002.05202
- "RMSNorm" (Zhang & Sennrich, 2019) — arXiv:1910.07467
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

from model.bitlinear import BitLinear158, RMSNorm, SwiGLU


# ── Shared components (self-contained, no dependency on pixel_decoder.py) ──

class PixelPositionalEncoding2D(nn.Module):
    """
    2D positional encoding for pixel sequences.

    Instead of treating pixel positions as flat indices 0..1023,
    we encode them as (row, col) pairs with separate learned embeddings.
    This gives the model explicit 2D spatial structure.

    Also includes a special position embedding for <sos> and <eos> tokens.
    """

    def __init__(self, d_model: int, img_size: int = 32):
        super().__init__()
        self.img_size = img_size
        self.d_model = d_model

        # Separate row and column embeddings
        self.row_embed = nn.Embedding(img_size, d_model // 2)
        self.col_embed = nn.Embedding(img_size, d_model // 2)

        # Special position for sos/eos tokens
        self.special_pos = nn.Embedding(2, d_model)  # 0=sos, 1=eos

        # Learnable scale
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """
        Generate positional encodings for a sequence of length seq_len.
        Sequence layout: [sos, pixel_0, pixel_1, ..., pixel_1023, eos]

        Returns: (1, seq_len, d_model)
        """
        positions = torch.zeros(1, seq_len, self.d_model, device=device)

        # SOS position
        positions[:, 0, :] = self.special_pos(torch.tensor([0], device=device))

        # Pixel positions (indices 1..1024)
        num_pixels = min(seq_len - 1, self.img_size * self.img_size)
        if num_pixels > 0:
            pixel_indices = torch.arange(num_pixels, device=device)
            rows = pixel_indices // self.img_size
            cols = pixel_indices % self.img_size

            row_emb = self.row_embed(rows)  # (num_pixels, d_model//2)
            col_emb = self.col_embed(cols)  # (num_pixels, d_model//2)
            pixel_pos = torch.cat([row_emb, col_emb], dim=-1)  # (num_pixels, d_model)
            positions[:, 1:1 + num_pixels, :] = pixel_pos.unsqueeze(0)

        # EOS position (if present)
        if seq_len > self.img_size * self.img_size + 1:
            positions[:, -1, :] = self.special_pos(torch.tensor([1], device=device))

        return positions * self.scale


class PaletteOutputHead(nn.Module):
    """
    Palette-aware output prediction.

    Instead of a flat linear(d_model -> vocab_size) layer, we compute
    output logits via scaled dot-product attention between the decoder
    hidden states and a set of learned palette key vectors.

    Each palette color has a key embedding initialized from its RGB values.
    This gives the model an inductive bias toward understanding color relationships.
    """

    def __init__(self, d_model: int, palette_size: int, num_special_tokens: int = 3):
        super().__init__()
        self.total_vocab = palette_size + num_special_tokens
        self.d_model = d_model

        # Learned palette keys (will be initialized from RGB values)
        self.palette_keys = nn.Parameter(torch.randn(self.total_vocab, d_model))

        # Query projection for hidden states
        self.query_proj = nn.Linear(d_model, d_model)

        # Temperature parameter for controlling sharpness
        self.temperature = nn.Parameter(torch.tensor(math.sqrt(d_model), dtype=torch.float32))

    def init_from_palette(self, palette_rgb: torch.Tensor):
        """
        Initialize palette key embeddings from RGB values.
        palette_rgb: (palette_size, 3) tensor of RGB values [0, 255]
        """
        with torch.no_grad():
            palette_size = palette_rgb.shape[0]
            # Normalize RGB to [-1, 1] and project to d_model
            rgb_norm = palette_rgb.float() / 127.5 - 1.0  # (palette_size, 3)
            # Repeat/tile to fill d_model dimensions
            repeats = self.d_model // 3 + 1
            expanded = rgb_norm.repeat(1, repeats)[:, :self.d_model]
            # Mix with some noise for diversity
            self.palette_keys.data[:palette_size] = expanded + 0.1 * torch.randn_like(expanded)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch, seq_len, d_model)
        Returns:
            logits: (batch, seq_len, total_vocab)
        """
        queries = self.query_proj(hidden_states)  # (batch, seq_len, d_model)
        # Scaled dot-product attention with palette keys
        logits = torch.matmul(queries, self.palette_keys.T) / self.temperature
        return logits


class BitMultiheadAttention(nn.Module):
    """
    Multi-head attention with BitLinear158 projections.

    Q, K, V projections and the output projection all use 1.58-bit weights.
    Attention computation itself remains in full precision.

    Following BitNet b1.58: the RMSNorm that normally precedes attention
    is absorbed into the BitLinear158 layers (they have built-in RMSNorm).
    """

    def __init__(self, d_model: int, nhead: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % nhead == 0, f"d_model ({d_model}) must be divisible by nhead ({nhead})"

        self.d_model = d_model
        self.nhead = nhead
        self.head_dim = d_model // nhead

        # QKV projections — all 1.58-bit
        self.q_proj = BitLinear158(d_model, d_model)
        self.k_proj = BitLinear158(d_model, d_model)
        self.v_proj = BitLinear158(d_model, d_model)

        # Output projection — 1.58-bit
        self.out_proj = BitLinear158(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            query: (batch, q_len, d_model)
            key: (batch, kv_len, d_model)
            value: (batch, kv_len, d_model)
            attn_mask: (q_len, kv_len) or (batch*nhead, q_len, kv_len)
            key_padding_mask: (batch, kv_len)
        Returns:
            (batch, q_len, d_model)
        """
        batch_size = query.size(0)

        # Project Q, K, V through 1.58-bit linear layers
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        # Reshape for multi-head: (batch, seq, d_model) -> (batch, nhead, seq, head_dim)
        q = q.view(batch_size, -1, self.nhead, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.nhead, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.nhead, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale

        # Apply causal mask
        if attn_mask is not None:
            if attn_mask.dim() == 2:
                attn_weights = attn_weights + attn_mask.unsqueeze(0).unsqueeze(0)
            else:
                attn_weights = attn_weights + attn_mask

        # Apply padding mask
        if key_padding_mask is not None:
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf')
            )

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention to values
        attn_output = torch.matmul(attn_weights, v)

        # Reshape back: (batch, nhead, seq, head_dim) -> (batch, seq, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        # Output projection (1.58-bit)
        return self.out_proj(attn_output)


class BitPixelLMDecoderLayer(nn.Module):
    """
    Single decoder layer with 1.58-bit weights.

    Structure (per BitNet b1.58 / LLaMA convention):
    1. Self-attention with BitLinear158 projections (RMSNorm built into BitLinear)
    2. Cross-attention to text encoder output (BitLinear158 projections)
    3. SwiGLU feed-forward network (BitLinear158 projections)

    Pre-norm architecture, but the norm is absorbed into BitLinear158.
    Residual connections use a separate RMSNorm for gradient stability.
    """

    def __init__(self, d_model: int, nhead: int, dim_ff: int, dropout: float = 0.0):
        super().__init__()

        # Self-attention (masked, causal)
        self.self_attn = BitMultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm1 = RMSNorm(d_model)

        # Cross-attention to text
        self.cross_attn = BitMultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm2 = RMSNorm(d_model)

        # SwiGLU feed-forward (replaces GELU FFN)
        self.ff = SwiGLU(d_model, hidden_features=dim_ff, use_bitlinear=True)
        self.norm3 = RMSNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        text_enc: torch.Tensor,
        causal_mask: torch.Tensor,
        text_pad_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)
            text_enc: (batch, text_len, d_model)
            causal_mask: (seq_len, seq_len) causal attention mask
            text_pad_mask: (batch, text_len) padding mask for text
        Returns:
            (batch, seq_len, d_model)
        """
        # Pre-norm self-attention with residual
        residual = x
        x = self.norm1(x)
        x = self.self_attn(x, x, x, attn_mask=causal_mask)
        x = self.dropout(x) + residual

        # Pre-norm cross-attention with residual
        residual = x
        x = self.norm2(x)
        x = self.cross_attn(x, text_enc, text_enc, key_padding_mask=text_pad_mask)
        x = self.dropout(x) + residual

        # Pre-norm SwiGLU FFN with residual
        residual = x
        x = self.norm3(x)
        x = self.ff(x)
        x = self.dropout(x) + residual

        return x


class BitPixelLMDecoder(nn.Module):
    """
    1.58-bit PixelLM Decoder.

    Same architecture as PixelLMDecoder but with:
    - BitLinear158 replacing all nn.Linear in attention and FFN
    - RMSNorm replacing LayerNorm (absorbed into BitLinear + residual norms)
    - SwiGLU replacing GELU FFN
    - No biases

    Full precision components (NOT quantized):
    - Token embeddings (need full precision for gradient flow to embeddings)
    - 2D positional encoding (our unique spatial encoding)
    - Palette output head (needs high-precision logits for sampling)
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        nhead: int = 8,
        num_layers: int = 6,
        dim_feedforward: int = 512,
        img_size: int = 32,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.img_size = img_size
        self.max_seq_len = img_size * img_size + 2

        # ── Full precision components ─────────────────────────────
        # Token embedding (kept in FP32)
        self.token_embed = nn.Embedding(vocab_size, d_model)

        # 2D positional encoding (our unique contribution — kept FP32)
        self.pos_encoding = PixelPositionalEncoding2D(d_model, img_size)

        # Palette-aware output head (kept FP32 for sampling precision)
        self.output_head = PaletteOutputHead(d_model, vocab_size - 3, num_special_tokens=3)

        # ── 1.58-bit components ───────────────────────────────────
        # Decoder layers with BitLinear158
        self.layers = nn.ModuleList([
            BitPixelLMDecoderLayer(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_layers)
        ])

        # Final norm (full precision RMSNorm)
        self.final_norm = RMSNorm(d_model)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Cache for causal mask
        self._causal_mask_cache = {}

    def _get_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Generate or retrieve cached causal attention mask."""
        if seq_len not in self._causal_mask_cache:
            mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
            float_mask = torch.zeros(seq_len, seq_len, device=device)
            float_mask.masked_fill_(mask, float('-inf'))
            self._causal_mask_cache[seq_len] = float_mask
        return self._causal_mask_cache[seq_len]

    def forward(
        self,
        pixel_tokens: torch.Tensor,
        text_enc: torch.Tensor,
        text_pad_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for training (teacher-forced).

        Args:
            pixel_tokens: (batch, seq_len) long tensor of pixel token indices
            text_enc: (batch, text_len, d_model) text encoder output
            text_pad_mask: (batch, text_len) True where text is padded
        Returns:
            logits: (batch, seq_len, vocab_size)
        """
        batch_size, seq_len = pixel_tokens.shape
        device = pixel_tokens.device

        # Token embeddings (full precision)
        x = self.token_embed(pixel_tokens) * math.sqrt(self.d_model)

        # 2D positional encoding (full precision)
        pos = self.pos_encoding(seq_len, device)
        x = x + pos
        x = self.dropout(x)

        # Causal mask
        causal_mask = self._get_causal_mask(seq_len, device)

        # 1.58-bit decoder layers
        for layer in self.layers:
            x = layer(x, text_enc, causal_mask, text_pad_mask)

        # Final norm
        x = self.final_norm(x)

        # Output logits via palette-aware head (full precision)
        logits = self.output_head(x)

        return logits

    @torch.no_grad()
    def generate(
        self,
        text_enc: torch.Tensor,
        sos_token: int,
        eos_token: int,
        max_len: int = 1026,
        temperature: float = 0.8,
        top_k: int = 40,
        top_p: float = 0.9,
        text_pad_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Autoregressive generation (same interface as PixelLMDecoder).
        """
        device = text_enc.device
        tokens = torch.tensor([[sos_token]], dtype=torch.long, device=device)

        for step in range(max_len - 1):
            logits = self.forward(tokens, text_enc, text_pad_mask)
            next_logits = logits[:, -1, :] / temperature

            # Top-k filtering
            if top_k > 0:
                topk_vals, _ = torch.topk(next_logits, top_k)
                next_logits[next_logits < topk_vals[:, -1:]] = float('-inf')

            # Top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_mask = cumulative_probs - F.softmax(sorted_logits, dim=-1) >= top_p
                sorted_logits[sorted_mask] = float('-inf')
                next_logits = sorted_logits.scatter(1, sorted_indices, sorted_logits)

            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            tokens = torch.cat([tokens, next_token], dim=1)

            if next_token.item() == eos_token:
                break

        return tokens


class BitPixelLM(nn.Module):
    """
    Complete 1.58-bit PixelLM: Text Encoder (FP32) + Pixel Decoder (1.58-bit).

    The text encoder remains in full precision because:
    1. It's small (3 layers) — quantization overhead would negate benefits
    2. Text understanding needs full precision for a small vocabulary

    The pixel decoder uses 1.58-bit weights for:
    1. All self-attention projections (Q, K, V, O)
    2. All cross-attention projections
    3. All FFN projections (SwiGLU)
    """

    def __init__(self, text_encoder: nn.Module, pixel_decoder: BitPixelLMDecoder):
        super().__init__()
        self.text_encoder = text_encoder
        self.pixel_decoder = pixel_decoder

    def forward(
        self,
        text_tokens: torch.Tensor,
        pixel_tokens: torch.Tensor,
    ) -> torch.Tensor:
        text_pad_mask = (text_tokens == 0)
        text_enc = self.text_encoder(text_tokens)
        logits = self.pixel_decoder(pixel_tokens, text_enc, text_pad_mask)
        return logits

    @torch.no_grad()
    def generate(
        self,
        text_tokens: torch.Tensor,
        sos_token: int,
        eos_token: int,
        **kwargs,
    ) -> torch.Tensor:
        text_pad_mask = (text_tokens == 0)
        text_enc = self.text_encoder(text_tokens)
        return self.pixel_decoder.generate(
            text_enc, sos_token, eos_token,
            text_pad_mask=text_pad_mask, **kwargs
        )

    def count_parameters(self) -> int:
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def count_bit_parameters(self) -> dict:
        """Count parameters by precision level."""
        bit_params = 0
        fp_params = 0
        for name, p in self.named_parameters():
            if not p.requires_grad:
                continue
            if 'pixel_decoder.layers' in name and '.weight' in name and 'norm' not in name and 'rms_norm' not in name:
                bit_params += p.numel()
            else:
                fp_params += p.numel()
        return {
            'ternary_params': bit_params,
            'fp32_params': fp_params,
            'total': bit_params + fp_params,
            'ternary_pct': bit_params / (bit_params + fp_params) * 100,
            'effective_bits': (bit_params * 1.58 + fp_params * 32) / (bit_params + fp_params),
        }


# ──── Testing ────────────────────────────────────────────────────

if __name__ == "__main__":
    import sys
    sys.path.insert(0, str(__import__('pathlib').Path(__file__).parent.parent))

    from model.text_encoder import TextEncoder

    print("Building BitPixelLM...")

    # Build text encoder (full precision)
    text_encoder = TextEncoder(
        vocab_size=66,  # 62 words + 4 special
        d_model=256,
        nhead=4,
        num_layers=3,
        dim_feedforward=512,
        max_seq_len=32,
    )

    # Build 1.58-bit pixel decoder
    pixel_decoder = BitPixelLMDecoder(
        vocab_size=259,
        d_model=256,
        nhead=8,
        num_layers=6,
        dim_feedforward=512,
        img_size=32,
    )

    model = BitPixelLM(text_encoder, pixel_decoder)

    # Parameter count
    total = model.count_parameters()
    breakdown = model.count_bit_parameters()
    print(f"\nBitPixelLM: {total:,} total parameters")
    print(f"  Ternary (1.58-bit): {breakdown['ternary_params']:,} ({breakdown['ternary_pct']:.1f}%)")
    print(f"  Full precision: {breakdown['fp32_params']:,} ({100 - breakdown['ternary_pct']:.1f}%)")
    print(f"  Effective bits/param: {breakdown['effective_bits']:.2f}")

    # Forward pass test
    text = torch.randint(0, 66, (2, 32))
    pixels = torch.randint(0, 259, (2, 1025))

    print("\nForward pass test...")
    logits = model(text, pixels)
    print(f"  Input: text={text.shape}, pixels={pixels.shape}")
    print(f"  Output: logits={logits.shape}")

    # Gradient test
    loss = logits[:, :, :259].sum()
    loss.backward()
    grad_ok = all(p.grad is not None for p in model.parameters() if p.requires_grad)
    print(f"  Gradient flow: {'OK' if grad_ok else 'FAILED'}")

    print("\nAll tests passed! ✓")
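The `effective_bits` figure reported by `count_bit_parameters` is just a weighted average of 1.58-bit and 32-bit parameters. A worked check, using the ~75% ternary share quoted in `app.py` (the exact split depends on the built model):

```python
# Worked example of the effective-bits average — the 75% ternary share is
# the figure quoted in app.py, not an exact count for this checkpoint.
ternary_share = 0.75
effective_bits = ternary_share * 1.58 + (1 - ternary_share) * 32.0
print(f"{effective_bits:.2f} effective bits/param")  # ≈ 9.2
```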
model/bitlinear.py ADDED
@@ -0,0 +1,239 @@
"""
PixelArtGen — BitLinear 1.58-bit Layer & RMSNorm

Implementation of the core BitNet b1.58 components:
- RMSNorm: Root Mean Square Layer Normalization (Zhang & Sennrich, 2019)
- BitLinear158: 1.58-bit linear layer with ternary weights {-1, 0, +1}

References:
- "The Era of 1-bit LLMs" (Ma et al., 2024) — arXiv:2402.17764
- "BitNet: Scaling 1-bit Transformers" (Wang et al., 2023) — arXiv:2310.11453
- "RMSNorm" (Zhang & Sennrich, 2019) — arXiv:1910.07467
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.

    Simpler and faster than LayerNorm — removes mean centering,
    keeps only the re-scaling by root mean square.

    RMSNorm(x) = x / RMS(x) * g
    where RMS(x) = sqrt(mean(x^2))

    Reference: arXiv:1910.07467
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def activation_quant(x: torch.Tensor) -> torch.Tensor:
    """
    Per-token 8-bit activation quantization from BitNet b1.58.

    Quantizes activations to [-127, 127] per-token using absmax scaling.
    Symmetric quantization (no zero-point) as specified in the paper.

    Args:
        x: (..., d_model) float tensor
    Returns:
        Fake-quantized tensor (still float for autograd compatibility),
        already rescaled back to the input's range
    """
    Qb = 127  # 8-bit signed: 2^(8-1) - 1
    scale = x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    x_quant = (x * Qb / scale).clamp(-Qb, Qb).round()
    # STE: detach the rounding, keep gradients flowing
    x_quant = x + (x_quant * scale / Qb - x).detach()
    return x_quant


def weight_quant(w: torch.Tensor) -> tuple:
    """
    Absmean ternary weight quantization from BitNet b1.58.

    Quantizes weights to {-1, 0, +1} using absmean scaling:
    1. Compute gamma = mean(|W|)
    2. Scale: W_scaled = W / gamma
    3. Round to nearest in {-1, 0, +1}

    Args:
        w: (out_features, in_features) weight matrix
    Returns:
        (quantized_weights, scale_factor)
    """
    gamma = w.abs().mean().clamp(min=1e-5)
    w_scaled = w / gamma
    w_quant = w_scaled.clamp(-1, 1).round()
    # STE: detach the rounding, keep gradients on the latent weights
    w_quant = w + (w_quant * gamma - w).detach()
    return w_quant, gamma


class BitLinear158(nn.Module):
    """
    1.58-bit Linear Layer from BitNet b1.58.

    Drop-in replacement for nn.Linear with:
    - Ternary weights {-1, 0, +1} via absmean quantization
    - 8-bit per-token activation quantization
    - Built-in RMSNorm (absorbs the preceding LayerNorm)
    - No bias (following BitNet b1.58 / LLaMA convention)
    - Full-precision latent weights maintained for training (STE)

    Forward pass:
    1. RMSNorm the input
    2. Quantize activations to 8-bit
    3. Quantize weights to ternary
    4. Matrix multiply (effectively integer addition)
    5. Rescale output

    During training, gradients flow through quantization via the
    Straight-Through Estimator (STE) — the gradient of round()
    is treated as the identity function.

    Reference: arXiv:2402.17764
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Full-precision latent weight (master copy for training)
        self.weight = nn.Parameter(torch.empty(out_features, in_features))

        # Built-in RMSNorm (replaces the preceding LayerNorm)
        self.rms_norm = RMSNorm(in_features)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Kaiming uniform initialization, same as nn.Linear."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, in_features)
        Returns:
            (batch, seq_len, out_features)
        """
        # 1. Normalize input (built-in RMSNorm)
        x = self.rms_norm(x)

        # 2. Quantize activations to 8-bit per-token
        x_q = activation_quant(x)

        # 3. Quantize weights to ternary {-1, 0, +1}
        w_q, w_scale = weight_quant(self.weight)

        # 4. Matrix multiply with quantized weights and activations
        #    In theory this is integer addition; in practice we use float
        #    for autograd compatibility during training
        output = F.linear(x_q, w_q)

        return output

    def extra_repr(self) -> str:
        return f"in={self.in_features}, out={self.out_features}, bits=1.58"


class SwiGLU(nn.Module):
    """
    SwiGLU activation for Feed-Forward Networks.

    SwiGLU(x) = (Swish(xW1) ⊙ xV) W2

    Uses 3 linear projections instead of 2, but the hidden dim
    is typically reduced by 2/3 to keep parameter count similar.

    When used with BitLinear158, all three projections are ternary.

    Reference: arXiv:2002.05202 (Shazeer, 2020)
    """

    def __init__(self, in_features: int, hidden_features: int = None, use_bitlinear: bool = True):
        super().__init__()
        hidden_features = hidden_features or int(in_features * 8 / 3)  # 2/3 of 4x expansion
        # Round to nearest multiple of 8 for efficiency
        hidden_features = ((hidden_features + 7) // 8) * 8

        Linear = BitLinear158 if use_bitlinear else nn.Linear

        self.w1 = Linear(in_features, hidden_features)  # gate projection
        self.v = Linear(in_features, hidden_features)   # value projection
        self.w2 = Linear(hidden_features, in_features)  # output projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.v(x))


# ──── Testing ────────────────────────────────────────────────────

if __name__ == "__main__":
    print("Testing BitLinear158 components...")

    # Test RMSNorm
    norm = RMSNorm(256)
    x = torch.randn(2, 10, 256)
    y = norm(x)
    print(f"RMSNorm: {x.shape} -> {y.shape}, mean={y.mean():.4f}, std={y.std():.4f}")

    # Test weight quantization
    w = torch.randn(512, 256)
    w_q, scale = weight_quant(w)
    unique = torch.unique(w_q.detach())
    print(f"Weight quant: {w.shape}, unique values: {len(unique)}, scale: {scale:.4f}")
    print(f"  Ternary distribution: -1={((w_q.detach().round() == -1).sum().item())}, "
          f"0={((w_q.detach().round() == 0).sum().item())}, "
          f"+1={((w_q.detach().round() == 1).sum().item())}")

    # Test activation quantization
    a = torch.randn(2, 10, 256)
    a_q = activation_quant(a)
    print(f"Activation quant: range [{a_q.min():.2f}, {a_q.max():.2f}]")

    # Test BitLinear158
    layer = BitLinear158(256, 512)
    x = torch.randn(2, 10, 256)
    y = layer(x)
    print(f"BitLinear158: {x.shape} -> {y.shape}")

    # Test gradient flow (STE)
    loss = y.sum()
    loss.backward()
    assert layer.weight.grad is not None, "Gradient did not flow through STE!"
    print(f"STE gradient flow: OK (grad norm: {layer.weight.grad.norm():.4f})")

    # Test SwiGLU
    swiglu = SwiGLU(256, use_bitlinear=True)
    x = torch.randn(2, 10, 256)
    y = swiglu(x)
    print(f"SwiGLU (BitLinear): {x.shape} -> {y.shape}")
    total = sum(p.numel() for p in swiglu.parameters())
    print(f"  SwiGLU params: {total:,}")

    # Parameter comparison
    ff_standard = nn.Sequential(nn.Linear(256, 512), nn.GELU(), nn.Linear(512, 256))
    ff_params = sum(p.numel() for p in ff_standard.parameters())
    print(f"  Standard FFN params: {ff_params:,}")
    print(f"  Ratio: {total / ff_params:.2f}x")

    print("\nAll tests passed! ✓")
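To make the absmean rule in `weight_quant` concrete, a tiny worked example with illustrative numbers:

```python
# Worked example of absmean ternary quantization — illustrative values only.
W = [0.4, -0.1, 0.8, -0.6]
gamma = sum(abs(w) for w in W) / len(W)            # mean(|W|) = 0.475
scaled = [w / gamma for w in W]                    # [0.84, -0.21, 1.68, -1.26]
ternary = [round(max(-1.0, min(1.0, s))) for s in scaled]
print(f"{gamma:.3f}", ternary)                     # 0.475 [1, 0, 1, -1]
```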
model/text_encoder.py ADDED
@@ -0,0 +1,122 @@
"""
PixelArtGen — Text Encoder

A small transformer encoder that converts text prompts into
contextual embeddings for conditioning the pixel art decoder.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List


class TextTokenizer:
    """Simple word-level tokenizer for text prompts."""

    def __init__(self, vocab: List[str]):
        self.word2idx = {w: i for i, w in enumerate(vocab)}
        self.idx2word = {i: w for i, w in enumerate(vocab)}
        self.pad_idx = self.word2idx.get("<pad>", 0)
        self.sos_idx = self.word2idx.get("<sos>", 1)
        self.eos_idx = self.word2idx.get("<eos>", 2)
        self.unk_idx = self.word2idx.get("<unk>", 3)
        self.vocab_size = len(vocab)

    def encode(self, text: str, max_len: int = 32) -> torch.Tensor:
        """Tokenize and pad a text prompt."""
        words = text.lower().strip().split()
        tokens = [self.sos_idx]
        for w in words:
            tokens.append(self.word2idx.get(w, self.unk_idx))
        tokens.append(self.eos_idx)

        # Pad or truncate
        if len(tokens) > max_len:
            tokens = tokens[:max_len]
        else:
            tokens += [self.pad_idx] * (max_len - len(tokens))

        return torch.tensor(tokens, dtype=torch.long)

    def encode_batch(self, texts: List[str], max_len: int = 32) -> torch.Tensor:
        """Encode a batch of text prompts."""
        return torch.stack([self.encode(t, max_len) for t in texts])


class TextEncoder(nn.Module):
    """
    Small transformer encoder for text prompts.

    Architecture:
    - Word embeddings + sinusoidal positional encoding
    - N transformer encoder layers with multi-head attention
    - Output: sequence of contextual embeddings (batch, seq_len, d_model)
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        nhead: int = 4,
        num_layers: int = 3,
        dim_feedforward: int = 512,
        max_seq_len: int = 32,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len)
        self.dropout = nn.Dropout(dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, text_tokens: torch.Tensor) -> torch.Tensor:
        """
        Args:
            text_tokens: (batch, seq_len) long tensor of word indices
        Returns:
            (batch, seq_len, d_model) contextual embeddings
        """
        # Create padding mask (True = ignore)
        pad_mask = (text_tokens == 0)  # pad_idx = 0

        # Embed + positional encode
        x = self.embedding(text_tokens) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        x = self.dropout(x)

        # Transformer encode
        x = self.transformer(x, src_key_padding_mask=pad_mask)
        x = self.norm(x)

        return x


class SinusoidalPositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding."""

    def __init__(self, d_model: int, max_len: int = 512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pe[:, :x.size(1)]
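A quick sanity check of the word-level tokenizer, using a toy vocabulary rather than the shipped `vocab.json`:

```python
# Toy vocabulary for illustration — the real model uses vocab.json (222 words).
from model.text_encoder import TextTokenizer

vocab = ["<pad>", "<sos>", "<eos>", "<unk>", "a", "red", "pixel", "art", "sword"]
tok = TextTokenizer(vocab)
ids = tok.encode("a red pixel art sword", max_len=12)
print(ids.tolist())  # [1, 4, 5, 6, 7, 8, 2, 0, 0, 0, 0, 0]
```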
model/tokenizer.py ADDED
@@ -0,0 +1,106 @@
"""
PixelArtGen — Color Palette Tokenizer

Converts 32×32 RGB pixel art images into sequences of palette indices
and back. This is the "vocabulary" for the pixel language model.
"""

import numpy as np
import torch
from pathlib import Path


class PaletteTokenizer:
    """
    Maps RGB pixels to/from a fixed palette of N colors.
    Each pixel becomes a token index ∈ [0, palette_size).

    Special tokens:
        palette_size     = <sos> (start of sequence)
        palette_size + 1 = <eos> (end of sequence)
        palette_size + 2 = <pad> (padding)
    """

    def __init__(self, palette_path: str = None, palette: np.ndarray = None, palette_size: int = 256):
        if palette is not None:
            self.palette = palette.astype(np.float32)
        elif palette_path is not None:
            self.palette = np.load(palette_path).astype(np.float32)
        else:
            raise ValueError("Must provide palette_path or palette array")

        self.palette_size = len(self.palette)
        self.sos_token = self.palette_size
        self.eos_token = self.palette_size + 1
        self.pad_token = self.palette_size + 2
        self.vocab_size = self.palette_size + 3  # colors + sos + eos + pad

    def rgb_to_index(self, rgb: np.ndarray) -> int:
        """Find the closest palette color for an RGB value."""
        distances = np.sum((self.palette - rgb.astype(np.float32)) ** 2, axis=1)
        return int(np.argmin(distances))

    def encode_image(self, img_array: np.ndarray) -> list:
        """
        Encode a 32×32×3 RGB image into a flat sequence of palette indices.
        Returns: [sos, p0, p1, ..., p1023, eos] (1026 tokens)
        """
        h, w, c = img_array.shape
        assert h == 32 and w == 32 and c == 3, f"Expected 32×32×3, got {img_array.shape}"

        tokens = [self.sos_token]
        for y in range(h):
            for x in range(w):
                pixel = img_array[y, x]
                idx = self.rgb_to_index(pixel)
                tokens.append(idx)
        tokens.append(self.eos_token)
        return tokens

    def encode_image_fast(self, img_array: np.ndarray) -> list:
        """
        Vectorized encoding — much faster than pixel-by-pixel.
        """
        h, w, c = img_array.shape
        pixels = img_array.reshape(-1, 3).astype(np.float32)  # (1024, 3)

        # Compute distances to all palette colors at once
        # pixels: (1024, 3), palette: (N, 3)
        diff = pixels[:, None, :] - self.palette[None, :, :]  # (1024, N, 3)
        distances = np.sum(diff ** 2, axis=2)  # (1024, N)
        indices = np.argmin(distances, axis=1)  # (1024,)

        tokens = [self.sos_token] + indices.tolist() + [self.eos_token]
        return tokens

    def decode_tokens(self, tokens: list) -> np.ndarray:
        """
        Decode a sequence of palette indices back to a 32×32×3 RGB image.
        Strips sos/eos/pad tokens.
        """
        # Filter special tokens
        pixel_tokens = [t for t in tokens if t < self.palette_size]

        # Pad or truncate to exactly 1024 pixels
        if len(pixel_tokens) < 1024:
            pixel_tokens += [0] * (1024 - len(pixel_tokens))
        pixel_tokens = pixel_tokens[:1024]

        img = np.zeros((1024, 3), dtype=np.uint8)
        for i, idx in enumerate(pixel_tokens):
            idx = min(idx, self.palette_size - 1)
            img[i] = self.palette[idx].astype(np.uint8)

        return img.reshape(32, 32, 3)

    def tokens_to_tensor(self, tokens: list, max_len: int = 1026) -> torch.Tensor:
        """Convert token list to padded tensor."""
        if len(tokens) > max_len:
            tokens = tokens[:max_len]
        else:
            tokens = tokens + [self.pad_token] * (max_len - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)

    def get_palette_tensor(self) -> torch.Tensor:
        """Return the palette as a (palette_size, 3) float32 tensor."""
        return torch.tensor(self.palette, dtype=torch.float32)
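A round-trip sanity check with a small synthetic palette (the shipped model uses `palette_256.npy` with 256 colors):

```python
# Synthetic 4-color palette for illustration — the real model uses palette_256.npy.
import numpy as np
from model.tokenizer import PaletteTokenizer

palette = np.array([[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255]], dtype=np.float32)
tok = PaletteTokenizer(palette=palette)

img = np.zeros((32, 32, 3), dtype=np.uint8)
img[:, :, 0] = 250  # near-red image snaps to palette entry 1

tokens = tok.encode_image_fast(img)
print(len(tokens), tokens[0], tokens[1], tokens[-1])  # 1026 4 1 5 (sos=4, eos=5)
decoded = tok.decode_tokens(tokens)
assert decoded.shape == (32, 32, 3) and (decoded[..., 0] == 255).all()
```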