STELLAR — Sparse Visual Representations via Spatial–Semantic Factorization
STELLAR learns a unified sparse visual representation that supports both reconstruction and semantics using as few as 16 tokens. By factorizing "what" (semantics) from "where" (spatial layout), each image is encoded as the low-rank product of a localization matrix and a semantics matrix.
- 📄 Paper: arXiv:2602.01905 (ICML 2026)
- 💻 Code: github.com/microsoft/STELLAR
These checkpoints contain the full set of trained STELLAR modules (encoder, sparse tokens, projections, reconstruction decoder, and clustering heads), so a single file supports feature extraction, image reconstruction, and continued pretraining. All models are self-supervised on ImageNet-1K at 224×224.
Highlights
- Sparse & unified — one small set of tokens serves both high-level semantics and pixel-level reconstruction.
- Factorized latents — each token captures a concept (what) together with a spatial map of where it appears.
- Strong on both axes — STELLAR-H reaches 2.60 FID (reconstruction) and 79.1% ImageNet linear-probing accuracy with just 16 tokens.
Available models
| Model | Backbone | Tokens | Params | Type | File |
|---|---|---|---|---|---|
| stellar-b16 | ViT-B/16 | 16 | 88M | main | stellar-b16.safetensors |
| stellar-l16 | ViT-L/16 | 16 | 307M | main | stellar-l16.safetensors |
| stellar-h16 | ViT-H/14 | 16 | 636M | main | stellar-h16.safetensors |
| stellar-b8 | ViT-B/16 | 8 | 88M | ablation | stellar-b8.safetensors |
| stellar-b24 | ViT-B/16 | 24 | 88M | ablation | stellar-b24.safetensors |
The main models (b16, l16, h16) are recommended for downstream use; the 8- and
24-token base models are ablations on the number of sparse tokens.
Usage
Install the STELLAR code and the Hub helpers:
pip install huggingface_hub safetensors
git clone https://github.com/microsoft/STELLAR && cd STELLAR
pip install -r requirements.txt
Quick start
From the STELLAR code directory, use the load_stellar.py helper
(it downloads the weights from the Hub for you):
import torch
from load_stellar import load_stellar, list_models
print(list_models()) # ['stellar-b16', 'stellar-l16', ...]
model = load_stellar("stellar-b16") # purpose="encode" (default)
# RGB image in [0, 1], resized to 224×224 (ImageNet normalization is applied internally)
image = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    out = model.encode(image)
out["sparse"] # (1, K, D) sparse concept tokens ("what")
out["spatial"] # (1, P, K) per-token spatial maps ("where")
out["dense"] # (1, P, D) dense per-patch features
out["cls"] # (1, 1, D) global image token
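The spatial maps come back flattened over patches. A minimal sketch of turning them into per-token 2-D heatmaps follows. It assumes the P patches are stored row-major over a square grid (14×14 for /16 models at 224²) and uses a random tensor in place of `out["spatial"]`. The helper `spatial_to_heatmaps` is hypothetical and not part of the STELLAR code.

```python
import math

import torch
import torch.nn.functional as F


def spatial_to_heatmaps(spatial: torch.Tensor, size: int = 224) -> torch.Tensor:
    """Reshape (B, P, K) spatial maps into (B, K, size, size) heatmaps.

    Assumes P is a perfect square and patches are stored row-major.
    """
    b, p, k = spatial.shape
    side = math.isqrt(p)
    if side * side != p:
        raise ValueError(f"P={p} is not a square grid")
    # (B, P, K) -> (B, K, P) -> (B, K, side, side)
    grid = spatial.transpose(1, 2).reshape(b, k, side, side)
    # Bilinear upsampling to image resolution, e.g. for overlaying on the input.
    return F.interpolate(grid, size=(size, size), mode="bilinear", align_corners=False)


# Stand-in for out["spatial"] from a 16-token /16 model at 224x224.
spatial = torch.rand(1, 196, 16)
heatmaps = spatial_to_heatmaps(spatial)
print(heatmaps.shape)  # torch.Size([1, 16, 224, 224])
```

Each `heatmaps[0, k]` can then be drawn over the input image to see where token `k` is active.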
Reconstruction & continued pretraining
The same checkpoint can be loaded for other purposes via the purpose argument. Image
reconstruction and continued pretraining use the decoder, which predicts
MaskGIT-VQGAN tokens — pass the tokenizer
path as vq_model:
# 1. encode -> factorized features (sparse concept tokens + spatial maps)
model = load_stellar("stellar-b16", purpose="reconstruct", vq_model=VQGAN_PATH)
features = model.encode(image) # dict: sparse (B,K,D), spatial (B,P,K), ...
# 2. decode the factorized features -> VQGAN decoder -> pixels
out = model.reconstruct(features) # or model.reconstruct(features["sparse"], features["spatial"])
pixels = out["reconstruction"] # (B, 3, H, W) RGB in [0, 1]
# 224x224 for /16 models, 256x256 for the /14 H model
# out["tokens"] : (B, P) predicted VQGAN token ids
# out["logits"] : (B, P, 1024) raw codebook logits
# continued pretraining (all modules, gradients enabled)
model = load_stellar("stellar-b16", purpose="pretrain", vq_model=VQGAN_PATH)
losses = model({"image": image, "labels": labels, ...})["predictions"]
reconstruct is the decoder half of STELLAR: it takes the factorized features and
runs low-rank dense map → ViT decoder → VQGAN decoder to return RGB pixels. See
examples/reconstruction.ipynb for an end-to-end demo
that loads an image and displays the reconstruction.
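The reconstruction is returned as a float tensor in [0, 1], so it usually needs a conversion step before saving or display. A minimal sketch, using a random tensor as a stand-in for `out["reconstruction"]` (a real call needs the checkpoint and the VQGAN tokenizer); `to_uint8_images` is a hypothetical helper, not part of the STELLAR code.

```python
import torch


def to_uint8_images(pixels: torch.Tensor) -> torch.Tensor:
    """Convert (B, 3, H, W) floats in [0, 1] to (B, H, W, 3) uint8 in [0, 255]."""
    pixels = pixels.detach().clamp(0.0, 1.0)
    images = (pixels * 255.0).round().to(torch.uint8)
    # Channels-last layout is what most image libraries expect.
    return images.permute(0, 2, 3, 1).contiguous().cpu()


# Stand-in for out["reconstruction"] from a /16 model.
pixels = torch.rand(2, 3, 224, 224)
images = to_uint8_images(pixels)
print(images.shape, images.dtype)  # torch.Size([2, 224, 224, 3]) torch.uint8
```

Each `images[i].numpy()` array can then be passed to an image library such as Pillow (`Image.fromarray`).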
What the model returns
| Key | Shape | Description | Typical use |
|---|---|---|---|
| sparse | (B, K, D) | sparse concept tokens | classification, retrieval |
| spatial | (B, P, K) | spatial map of each token | segmentation, visualization |
| dense | (B, P, D) | dense per-patch features | segmentation |
| lowrank | (B, P, D) | reassembled dense map | reconstruction |
| cls | (B, 1, D) | global representation | classification |
B = batch, K = number of sparse tokens, P = number of patches (196 for /16 at
224², 256 for /14), D = embedding dim (768 / 1024 / 1280 for B / L / H).
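The shapes above fit together as a batched matrix product. The sketch below assumes the reassembled `lowrank` map is the plain product of the spatial maps and the sparse tokens; the actual model may normalize the spatial maps first (for example with `spatial_temp`), so treat this as an illustration of the shapes rather than the exact computation. Random tensors stand in for real features.

```python
import torch

B, K, D = 2, 16, 768                   # batch, sparse tokens, embedding dim (ViT-B)
image_size, patch_size = 224, 16
P = (image_size // patch_size) ** 2    # 196 patches for /16 at 224x224

sparse = torch.randn(B, K, D)          # "what": one D-dim concept vector per token
spatial = torch.rand(B, P, K)          # "where": weight of each token at each patch

# Assumed reassembly: each patch feature is a weighted sum of the K concept vectors.
lowrank = torch.bmm(spatial, sparse)   # (B, P, K) @ (B, K, D) -> (B, P, D)

print(P, lowrank.shape)                # 196 torch.Size([2, 196, 768])
# The rank of each per-image map is bounded by the number of sparse tokens K.
rank = torch.linalg.matrix_rank(lowrank[0]).item()
print(rank <= K)                       # True
```

With 16 tokens, a 196×768 dense map is described by a 196×16 and a 16×768 factor, which is where the compactness comes from.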
Loading the weights manually
import json, torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from src.models.stellar_model import STELLARModel
repo = "microsoft/STELLAR"
cfg = json.load(open(hf_hub_download(repo, "config.json")))["models"]["stellar-b16"]
state = load_file(hf_hub_download(repo, cfg["weights"]))
model = STELLARModel(
    num_sparse_tokens=cfg["num_sparse_tokens"],
    num_decoder_layers=cfg["num_decoder_layers"],
    spatial_temp=cfg["spatial_temp"],
    vit_pretrained=cfg["backbone"],
    do_recon=False, do_clustering=False, vq_model=None,
)
model.load_state_dict(state, strict=False) # encoder-only build ignores decoder/heads
model.eval()
features = model.encode(torch.rand(1, 3, 224, 224))
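With `strict=False`, `load_state_dict` does not raise on mismatched keys, so a wrong key prefix could leave encoder weights at their initial values without any error. PyTorch returns the missing and unexpected keys, which makes this easy to check. The sketch below uses a tiny stand-in module and a hand-built state dict; `TinyEncoder` and its key names are hypothetical and do not reflect STELLAR's real parameter names.

```python
import torch
from torch import nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for an encoder-only build."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(4, 4)


# Stand-in checkpoint that also carries decoder weights the model does not have.
state = {
    "encoder.weight": torch.zeros(4, 4),
    "encoder.bias": torch.zeros(4),
    "decoder.weight": torch.zeros(4, 4),
}

model = TinyEncoder()
result = model.load_state_dict(state, strict=False)

# Keys the model expected but the checkpoint lacked; empty for a good load.
print(result.missing_keys)     # []
# Keys in the checkpoint with no matching module; decoder/head weights land here.
print(result.unexpected_keys)  # ['decoder.weight']
assert not result.missing_keys, "some model weights were left at their initial values"
```

For the encoder-only build, a non-empty `unexpected_keys` list is expected, while anything in `missing_keys` deserves a closer look.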
Tip: download with huggingface_hub (as above) rather than git clone so that downloads are registered on the Hub; git clone is not counted in download stats.
Model details
- Architecture: ViT encoder (MAE-initialized) + learned sparse latent queries with spatial–semantic factorization.
- Pretraining data: ImageNet-1K (self-supervised; labels not used).
- Input: RGB images in [0, 1], resized to 224×224 (bicubic). ImageNet mean/std normalization is applied inside the model, so pass raw [0, 1] images.
- Weights: the complete set of trained STELLAR modules (encoder, sparse tokens, projections, reconstruction decoder, and clustering heads), stored in safetensors. Only the third-party MaskGIT-VQGAN tokenizer is excluded; it is downloaded separately (from TiTok) and passed via vq_model.
- Framework: PyTorch.
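The input convention above can be sketched as a small preprocessing step: scale to [0, 1], resize with bicubic interpolation, and skip mean/std normalization because the model applies it internally. The antialiasing flag, the final clamp, and the direct resize that ignores aspect ratio are choices made for this sketch, not details confirmed by the STELLAR code; `prepare_image` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def prepare_image(img_uint8: torch.Tensor, size: int = 224) -> torch.Tensor:
    """Turn an (H, W, 3) uint8 RGB image into a (1, 3, size, size) float tensor in [0, 1]."""
    x = img_uint8.permute(2, 0, 1).unsqueeze(0).float() / 255.0  # (1, 3, H, W)
    x = F.interpolate(
        x, size=(size, size), mode="bicubic", align_corners=False, antialias=True
    )
    # Bicubic interpolation can overshoot slightly, so clamp back into [0, 1].
    return x.clamp(0.0, 1.0)


# Random stand-in for a decoded photo.
img = torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8)
batch = prepare_image(img)
print(batch.shape)  # torch.Size([1, 3, 224, 224])
```

The resulting `batch` has the layout that `model.encode` is described as accepting.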
Intended uses & limitations
- Intended use: extracting compact sparse/dense visual features for downstream recognition, segmentation, retrieval, reconstruction, and analysis.
- Limitations: pretrained on ImageNet-1K at 224×224, so features reflect that distribution; performance on very different domains (e.g. medical, satellite) may require fine-tuning. The models are research artifacts and are not safety-tested for production decision-making.
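For downstream recognition on a new domain, a common first step before fine-tuning is a linear probe on frozen features. The sketch below runs one training step on random stand-ins for `out["sparse"]` and the labels. Averaging the K tokens before the linear layer is an assumption made here for simplicity; it is not necessarily the probing protocol used in the paper.

```python
import torch
from torch import nn

B, K, D, num_classes = 8, 16, 768, 10

# Stand-ins for frozen out["sparse"] features and downstream labels.
sparse = torch.randn(B, K, D)
labels = torch.randint(0, num_classes, (B,))

probe = nn.Linear(D, num_classes)
optimizer = torch.optim.SGD(probe.parameters(), lr=0.1)

pooled = sparse.mean(dim=1)   # assumed pooling: average the K tokens -> (B, D)
logits = probe(pooled)        # (B, num_classes)
loss = nn.functional.cross_entropy(logits, labels)

optimizer.zero_grad()
loss.backward()               # gradients reach only the probe parameters
optimizer.step()
print(logits.shape)           # torch.Size([8, 10])
```

In practice the features would be extracted once under `torch.no_grad()` and cached, since the encoder stays frozen.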
Citation
@inproceedings{zhao2026stellar,
title = {Learning Sparse Visual Representations via Spatial-Semantic Factorization},
author = {Zhao, Theodore Zhengde and Kiblawi, Sid and Yang, Jianwei and Usuyama, Naoto and Tan, Reuben and Codella, Noel C and Naumann, Tristan and Poon, Hoifung and Wei, Mu},
booktitle = {International Conference on Machine Learning (ICML)},
year = {2026},
url = {https://arxiv.org/abs/2602.01905},
}
License
Released under the MIT License.