LocateAnything-3B: ONNX WebGPU (INT4 + 4-bit embeddings)
In-browser (onnxruntime-web / WebGPU) build of
nvidia/LocateAnything-3B, a visual-grounding /
open-vocabulary detector. The language tower is weight-only INT4 and the embedding table is
true group-wise 4-bit.
Why this repo exists
The naive "4-bit ONNX" of this model was ~3GB because the model has tied word embeddings
(vocab 152681 × hidden 2048 = 1.25GB in fp32). ORT's MatMulNBits INT4 quantizer compresses the
tied lm_head MatMul but leaves the input-embedding Gather at full fp32, so 1.25GB of fp32
embeddings stayed in the package.
This build fixes that with a custom quantized embedding gather:
- The language graph was surgically rewired to consume `inputs_embeds` directly (the fp32 embedding `Gather` and its 1.25GB initializer are removed). It still takes `input_ids` (used only by the SDLM block-mask `==` comparisons, not a Gather) and `visual_features` (spliced at the image token).
- The embedding table ships as a group-wise symmetric INT4 blob (`(q-8)·scale`, block size 32): `embed_tokens_int4_packed.bin` (uint8 nibble-packed) + `embed_tokens_int4_scales.bin` (fp16).
- The browser does the gather + dequant in JS to build `inputs_embeds`, then runs the INT4 language graph.
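The packed layout can be illustrated from the producer side. The following is only a sketch under stated assumptions: the per-group scale is taken as `max(|x|) / 7`, which this README does not specify (it only fixes the dequant form `(q-8)·scale` and the block size of 32), and `quantizeRowInt4` is a hypothetical helper name. Scales are kept in a `Float32Array` here for brevity, whereas the shipped file stores fp16.

```javascript
// Hypothetical helper: quantize one fp32 row to group-wise symmetric INT4.
// ASSUMPTION: scale = max(|x|) / 7 per group; the README only fixes the
// dequant form (q - 8) * scale and the group size of 32.
const GROUP = 32;

function quantizeRowInt4(row) {
  const hidden = row.length; // must be a multiple of GROUP (and therefore even)
  const packed = new Uint8Array(hidden / 2);
  const scales = new Float32Array(hidden / GROUP);
  for (let g = 0; g < hidden / GROUP; g++) {
    let maxAbs = 0;
    for (let j = 0; j < GROUP; j++) {
      maxAbs = Math.max(maxAbs, Math.abs(row[g * GROUP + j]));
    }
    const scale = maxAbs > 0 ? maxAbs / 7 : 1; // avoid dividing by zero on all-zero groups
    scales[g] = scale;
    for (let j = 0; j < GROUP; j++) {
      const i = g * GROUP + j;
      // Store q = round(x / scale) + 8, clamped to the uint4 range 0..15.
      const q = Math.min(15, Math.max(0, Math.round(row[i] / scale) + 8));
      // Even element index -> low nibble, odd element index -> high nibble.
      packed[i >> 1] |= (i & 1) ? (q << 4) : q;
    }
  }
  return { packed, scales };
}
```

With hidden = 2048 this yields 1024 packed bytes and 64 scales per token, which matches the per-row shapes listed for the two `.bin` files.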
Files (browser-facing)
| File | Size | Notes |
|---|---|---|
| `onnx/vision_mlp.onnx` (+`.data`) | ~1.73 GB | MoonViT + projector, fp32 (see note below) |
| `onnx/language_tail_int4.onnx` (+`.data`) | ~1.69 GB | Qwen2 language tower + tied lm_head, weight-only INT4 (block 128) |
| `onnx/embed_tokens_int4_packed.bin` | ~156 MB | INT4 embedding table, uint8 nibble-packed [152681, 1024] |
| `onnx/embed_tokens_int4_scales.bin` | ~19.5 MB | fp16 group scales [152681, 64] |
| `onnx/embed_tokens_int4_meta.json` | - | layout / dequant scheme |
| `web_config.json` | - | runtime wiring, token ids, tail size |
Total browser payload ≈ 3.6 GB. The big win here is the language side: the embedding table dropped from 1.25 GB fp32 → 176 MB INT4 and the language tail from 2.9 GB → 1.69 GB.
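The embedding sizes quoted above follow directly from the tensor shapes and dtypes. A quick check of the arithmetic (decimal MB/GB, i.e. powers of 10, is assumed here):

```javascript
// Shapes are taken from the file table; byte sizes follow from the dtypes.
const VOCAB = 152681;
const HIDDEN = 2048;
const GROUP = 32;

const fp32Bytes = VOCAB * HIDDEN * 4;            // fp32 table: 4 bytes per element
const packedBytes = VOCAB * (HIDDEN / 2);        // two 4-bit values per uint8
const scaleBytes = VOCAB * (HIDDEN / GROUP) * 2; // one fp16 scale per 32-wide group

const toGB = (b) => (b / 1e9).toFixed(2);
const toMB = (b) => (b / 1e6).toFixed(1);

console.log(`fp32 table:  ${toGB(fp32Bytes)} GB`);                // 1.25 GB
console.log(`INT4 packed: ${toMB(packedBytes)} MB`);              // 156.3 MB
console.log(`fp16 scales: ${toMB(scaleBytes)} MB`);               // 19.5 MB
console.log(`INT4 total:  ${toMB(packedBytes + scaleBytes)} MB`); // 175.9 MB
```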
Vision precision note. The vision tower's linears export as ONNX `Gemm` / dynamic `MatMul`, which ORT's `MatMulNBits` INT4 quantizer cannot compress, so it ships fp32. fp16 conversion is blocked by the explicit `.float()` Cast islands in the ONNX-friendly MoonViT RoPE patch (post-hoc fp16 conversion produces type clashes; native fp16 export hits a torch/MPS `expand_as` + float64 limitation). A native mixed-precision vision re-export (Conv in fp32, rest fp16) is the planned follow-up to cut this to ~0.9 GB.
Embedding gather / dequant (JS reference)
row = packed[token_id] // uint8[hidden/2]
low = row & 0x0F ; high = (row >> 4) // two nibbles per byte (low = even idx, high = odd idx)
q = interleave(low, high) // uint4[hidden], values 0..15
emb = (q - 8) * scales[token_id][j/32] // fp32[hidden]; one scale per 32-wide group
The language graph then splices visual_features over the image-token positions and applies the
SDLM block mask from input_ids.
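The reference above can be turned into runnable JavaScript roughly as follows. This is a sketch, not the demo's actual code: `fp16ToNumber` and `gatherDequant` are hypothetical names, and it assumes the scales file has been loaded as a `Uint16Array` of raw fp16 bits and the packed file as a `Uint8Array`, both row-major.

```javascript
// Decode one IEEE 754 half-precision value (given as its uint16 bit pattern).
function fp16ToNumber(h) {
  const sign = (h & 0x8000) ? -1 : 1;
  const exp = (h >> 10) & 0x1f;
  const frac = h & 0x03ff;
  if (exp === 0) return sign * frac * 2 ** -24;        // zero / subnormal
  if (exp === 31) return frac ? NaN : sign * Infinity; // NaN / infinity
  return sign * (1 + frac / 1024) * 2 ** (exp - 15);   // normal
}

// packed: Uint8Array  [vocab, hidden/2]     (two nibbles per byte)
// scales: Uint16Array [vocab, hidden/group] (raw fp16 bits)
// Returns a Float32Array of shape [tokenIds.length, hidden].
function gatherDequant(packed, scales, tokenIds, hidden = 2048, group = 32) {
  const bytesPerRow = hidden / 2;
  const groupsPerRow = hidden / group;
  const out = new Float32Array(tokenIds.length * hidden);
  tokenIds.forEach((id, t) => {
    const row = Number(id); // also accepts BigInt ids from an int64 buffer
    for (let b = 0; b < bytesPerRow; b++) {
      const byte = packed[row * bytesPerRow + b];
      const lo = byte & 0x0f; // even element index
      const hi = byte >> 4;   // odd element index
      const j = 2 * b;        // j and j + 1 share a group because group is even
      const s = fp16ToNumber(scales[row * groupsPerRow + Math.floor(j / group)]);
      out[t * hidden + j] = (lo - 8) * s;
      out[t * hidden + j + 1] = (hi - 8) * s;
    }
  });
  return out;
}
```

The result can then be wrapped as a float32 tensor of shape `[1, seq, hidden]` and fed to the graph as `inputs_embeds`.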
Validation
Validated against the fp32 PyTorch model on the sample image (slow / autoregressive mode):
- Next-token argmax matches PyTorch exactly (token `151672` = `<ref>` start).
- INT4 embedding gather error vs fp32 embeddings: `max_abs 0.017`, `mean_rel ≈ 10%` (per element), contributing only `~0.98` to the final logits (argmax-preserving).
- The dominant INT4 weight error (`~12.6` max logit delta) is unchanged from the baseline INT4 build.
- fp16 vision vs fp32 vision: see `validation_report.json`.
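For readers who want to reproduce this kind of comparison, the metrics could be computed along these lines. The exact definitions used for the reported numbers are not given here, so the formulas in the comments are assumptions, and `errorStats` / `argmax` are hypothetical helper names.

```javascript
// ASSUMED definitions (not spelled out in this README):
//   max_abs  = max_i |a_i - b_i|
//   mean_rel = mean_i |a_i - b_i| / (|b_i| + eps)
function errorStats(approx, reference, eps = 1e-8) {
  if (approx.length !== reference.length) throw new Error("length mismatch");
  let maxAbs = 0;
  let relSum = 0;
  for (let i = 0; i < approx.length; i++) {
    const d = Math.abs(approx[i] - reference[i]);
    maxAbs = Math.max(maxAbs, d);
    relSum += d / (Math.abs(reference[i]) + eps);
  }
  return { maxAbs, meanRel: relSum / approx.length };
}

// Index of the largest logit; ties resolve to the lowest index.
function argmax(xs) {
  let best = 0;
  for (let i = 1; i < xs.length; i++) {
    if (xs[i] > xs[best]) best = i;
  }
  return best;
}
```

An "argmax-preserving" check then amounts to comparing `argmax(quantizedLogits)` with `argmax(referenceLogits)` at the same position.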
Generation mode: use `slow` (autoregressive, greedy). The earlier prefill tail positions used by the `fast` / MTP path diverge under INT4 and are not relied upon here.
Intended use
Open-vocabulary detection / visual grounding: given an image and a category prompt
(Locate all the instances that matches the following description: <category>.), the model emits
<ref>label</ref><box>x1 y1 x2 y2</box> with coordinates normalized to 0–1000.
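Turning the emitted text into pixel boxes is a small parsing step. The sketch below assumes every box is directly preceded by its own `<ref>` span and that coordinates are whitespace-separated numbers; `parseDetections` is a hypothetical helper, not part of the model's tooling.

```javascript
// Parse "<ref>label</ref><box>x1 y1 x2 y2</box>" spans and map the 0..1000
// normalized coordinates to pixel coordinates of the original image.
function parseDetections(text, imageWidth, imageHeight) {
  const re = /<ref>(.*?)<\/ref>\s*<box>\s*([\d.]+)\s+([\d.]+)\s+([\d.]+)\s+([\d.]+)\s*<\/box>/g;
  const detections = [];
  for (const m of text.matchAll(re)) {
    const [x1, y1, x2, y2] = m.slice(2, 6).map(Number);
    detections.push({
      label: m[1],
      box: [
        (x1 * imageWidth) / 1000,
        (y1 * imageHeight) / 1000,
        (x2 * imageWidth) / 1000,
        (y2 * imageHeight) / 1000,
      ],
    });
  }
  return detections;
}

// Example: a 640x480 image with one detection.
const dets = parseDetections("<ref>cat</ref><box>100 200 500 1000</box>", 640, 480);
console.log(dets[0].label, dets[0].box); // label "cat", box [64, 96, 320, 480]
```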
KV-cache graph for in-browser use
onnx/language_tail_kv_int4.onnx (+.data, ~1.65 GB) is the KV-cache version of the language
tail used by the live demo. It takes inputs_embeds (+ input_ids for the plain-causal mask,
position_ids, and 36×2 past_key/value GQA tensors [1,2,seq,128]) and returns logits for the
last position plus present_key/value. Prefill passes length-0 past; decode passes the growing
cache. This makes autoregressive decoding ~13× faster than the cache-less graph. Validated: prefill
(empty past) → cached decode reproduces <ref>label</ref><box>…</box> detections; next-token argmax
matches the fp32 torch model. See kv_validation_report.json.
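The prefill-then-decode flow can be sketched independently of the runtime. In the code below, `step` is a hypothetical stand-in for one call into the KV-cache graph (in a real page it would wrap a session run and build `inputs_embeds`, `input_ids` and `position_ids`); the tensor input names, the cache dtype (fp32 is assumed) and the stop-token id are not specified here and are left to the caller.

```javascript
// Cache geometry from the description above: 36 layers x (key, value),
// each of shape [1, 2, seq, 128].
const NUM_LAYERS = 36, KV_HEADS = 2, HEAD_DIM = 128;

// Length-0 past for prefill. ASSUMPTION: fp32 cache tensors.
function emptyPast() {
  const past = [];
  for (let layer = 0; layer < NUM_LAYERS; layer++) {
    for (const kind of ["key", "value"]) {
      past.push({ layer, kind, dims: [1, KV_HEADS, 0, HEAD_DIM], data: new Float32Array(0) });
    }
  }
  return past;
}

function argmax(xs) {
  let best = 0;
  for (let i = 1; i < xs.length; i++) if (xs[i] > xs[best]) best = i;
  return best;
}

// step(tokenIds, past) -> Promise<{ logits, present }>, where logits covers
// only the last position and present is the updated cache.
async function greedyDecode(step, promptIds, eosId, maxNewTokens) {
  // Prefill: the whole prompt with an empty cache.
  let { logits, present } = await step(promptIds, emptyPast());
  const generated = [];
  for (let n = 0; n < maxNewTokens; n++) {
    const next = argmax(logits);
    if (next === eosId) break;
    generated.push(next);
    // Decode: feed only the new token; the cache carries the earlier context.
    ({ logits, present } = await step([next], present));
  }
  return generated;
}
```

Because each decode step processes a single token against the cached keys and values, the per-token cost no longer grows with re-running the full sequence, which is where the speedup over the cache-less graph comes from.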
Live in-browser demo (WebGPU): https://huggingface.co/spaces/Reza2kn/LocateAnything-3B-WebGPU
Source model & license: nvidia/LocateAnything-3B.