# BitPixelLM — model/tokenizer.py
# Source: Hugging Face upload "Upload BitPixelLM model artifacts"
# by BlakePeavy (commit 72e872c, verified)
"""
PixelArtGen — Color Palette Tokenizer
Converts 32×32 RGB pixel art images into sequences of palette indices
and back. This is the "vocabulary" for the pixel language model.
"""
import numpy as np
import torch
from pathlib import Path
class PaletteTokenizer:
    """
    Maps RGB pixels to/from a fixed palette of N colors.

    Each pixel becomes a token index in [0, palette_size). Three special
    tokens follow the color indices:

        palette_size     = <sos> (start of sequence)
        palette_size + 1 = <eos> (end of sequence)
        palette_size + 2 = <pad> (padding)

    Images are square ``image_size × image_size`` RGB arrays (default 32×32,
    i.e. 1024 pixels → sequences of 1026 tokens including <sos>/<eos>).
    """

    def __init__(self, palette_path: str = None, palette: np.ndarray = None,
                 palette_size: int = 256, image_size: int = 32):
        """
        Build a tokenizer from an in-memory palette or a .npy file.

        Args:
            palette_path: path to a .npy file holding an (N, 3) RGB palette.
            palette: an (N, 3) RGB palette array; takes precedence over
                ``palette_path`` when both are given.
            palette_size: kept for backward compatibility only — the actual
                palette size is always taken from the loaded palette's length,
                and this argument is ignored.
            image_size: side length of the square images this tokenizer
                encodes/decodes (default 32).

        Raises:
            ValueError: if neither ``palette_path`` nor ``palette`` is given.
        """
        if palette is not None:
            self.palette = palette.astype(np.float32)
        elif palette_path is not None:
            self.palette = np.load(palette_path).astype(np.float32)
        else:
            raise ValueError("Must provide palette_path or palette array")
        # True vocabulary layout is derived from the loaded palette, not from
        # the (ignored) palette_size argument.
        self.palette_size = len(self.palette)
        self.sos_token = self.palette_size
        self.eos_token = self.palette_size + 1
        self.pad_token = self.palette_size + 2
        self.vocab_size = self.palette_size + 3  # colors + sos + eos + pad
        self.image_size = image_size
        self.num_pixels = image_size * image_size  # pixels per image
        self.seq_len = self.num_pixels + 2         # + <sos> and <eos>

    def rgb_to_index(self, rgb: np.ndarray) -> int:
        """Return the index of the palette color closest (L2) to ``rgb``."""
        distances = np.sum((self.palette - rgb.astype(np.float32)) ** 2, axis=1)
        return int(np.argmin(distances))

    def _check_shape(self, img_array: np.ndarray) -> None:
        """Validate that ``img_array`` is image_size × image_size × 3."""
        s = self.image_size
        assert img_array.shape == (s, s, 3), \
            f"Expected {s}×{s}×3, got {img_array.shape}"

    def encode_image(self, img_array: np.ndarray) -> list:
        """
        Encode an image_size×image_size×3 RGB image into a flat token list.

        Pixels are scanned row-major (y outer, x inner), each mapped to its
        nearest palette index.

        Returns: [sos, p0, p1, ..., eos] — ``num_pixels + 2`` tokens.
        """
        self._check_shape(img_array)
        h, w, _ = img_array.shape
        tokens = [self.sos_token]
        for y in range(h):
            for x in range(w):
                tokens.append(self.rgb_to_index(img_array[y, x]))
        tokens.append(self.eos_token)
        return tokens

    def encode_image_fast(self, img_array: np.ndarray) -> list:
        """
        Vectorized encoding — same result as :meth:`encode_image`, but the
        nearest-palette search runs as one NumPy broadcast instead of a
        per-pixel Python loop.
        """
        # Validate shape here too; the original fast path silently accepted
        # wrong-shaped input and produced garbage tokens.
        self._check_shape(img_array)
        pixels = img_array.reshape(-1, 3).astype(np.float32)   # (P, 3)
        # Squared L2 distance of every pixel to every palette color at once:
        # pixels (P, 3) vs palette (N, 3) → diff (P, N, 3) → distances (P, N).
        diff = pixels[:, None, :] - self.palette[None, :, :]
        distances = np.sum(diff ** 2, axis=2)
        indices = np.argmin(distances, axis=1)                 # (P,)
        return [self.sos_token] + indices.tolist() + [self.eos_token]

    def decode_tokens(self, tokens: list) -> np.ndarray:
        """
        Decode a token sequence back to an image_size×image_size×3 RGB image.

        Special tokens (sos/eos/pad) and any out-of-range values are dropped;
        the pixel stream is then zero-padded or truncated to exactly
        ``num_pixels`` entries.
        """
        # Keep only valid palette indices. Note the lower bound: a negative
        # token would otherwise index the palette from the end.
        idx = np.array([t for t in tokens if 0 <= t < self.palette_size],
                       dtype=np.int64)
        n = self.num_pixels
        if len(idx) < n:
            # Too few pixels: pad with palette color 0.
            idx = np.concatenate([idx, np.zeros(n - len(idx), dtype=np.int64)])
        idx = idx[:n]
        # Single vectorized gather replaces the original per-pixel loop.
        img = self.palette[idx].astype(np.uint8)
        return img.reshape(self.image_size, self.image_size, 3)

    def tokens_to_tensor(self, tokens: list, max_len: int = None) -> torch.Tensor:
        """
        Convert a token list to a fixed-length LongTensor, padding with
        <pad> or truncating to ``max_len`` (defaults to the full sequence
        length, ``num_pixels + 2`` — 1026 for 32×32 images).
        """
        if max_len is None:
            max_len = self.seq_len
        if len(tokens) > max_len:
            tokens = tokens[:max_len]
        else:
            tokens = tokens + [self.pad_token] * (max_len - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)

    def get_palette_tensor(self) -> torch.Tensor:
        """Return the palette as a (palette_size, 3) float32 tensor."""
        return torch.tensor(self.palette, dtype=torch.float32)