LocateAnything-3B: ONNX WebGPU (INT4 + 4-bit embeddings)

In-browser (onnxruntime-web / WebGPU) build of nvidia/LocateAnything-3B, a visual-grounding / open-vocabulary detector. The language tower is weight-only INT4 and the embedding table is true group-wise 4-bit.

Why this repo exists

The naive "4-bit ONNX" of this model was ~3 GB because the model has tied word embeddings (vocab 152681 × hidden 2048 = 1.25 GB in fp32). ORT's MatMulNBits INT4 quantizer compresses the tied lm_head MatMul but leaves the input-embedding Gather at full fp32, so 1.25 GB of fp32 embeddings stayed in the package.

This build fixes that with a custom quantized embedding gather:

  1. The language graph was surgically rewired to consume inputs_embeds directly (the fp32 embedding Gather and its 1.25GB initializer are removed). It still takes input_ids (used only by the SDLM block-mask == comparisons, not a Gather) and visual_features (spliced at the image token).
  2. The embedding table ships as a group-wise symmetric INT4 blob ((q-8)·scale, block size 32): embed_tokens_int4_packed.bin (uint8 nibble-packed) + embed_tokens_int4_scales.bin (fp16).
  3. The browser does the gather + dequant in JS to build inputs_embeds, then runs the INT4 language graph.
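
As an illustration of step 2, here is a minimal sketch of how one embedding row could be turned into the packed nibbles and per-group scales. The source fixes only the dequant form (q-8)·scale with block size 32; the scale choice (max|x| / 7), the rounding, and the function name quantizeRow are assumptions made for this example, not the exporter's actual code.

```javascript
// Sketch: group-wise symmetric INT4 quantization of one embedding row.
// ASSUMPTION: scale = max|x| / 7 per group and round-to-nearest; the real
// exporter may pick scales differently. Scales are kept as fp32 here,
// whereas the shipped file stores them as fp16.
function quantizeRow(row, group = 32) {
  const hidden = row.length;
  const packed = new Uint8Array(hidden / 2);
  const scales = new Float32Array(hidden / group);
  const toNibble = (x, scale) =>
    Math.min(15, Math.max(0, Math.round(x / scale) + 8));

  for (let g = 0; g < hidden / group; g++) {
    // Largest magnitude in this group decides the scale.
    let maxAbs = 0;
    for (let j = 0; j < group; j++) {
      maxAbs = Math.max(maxAbs, Math.abs(row[g * group + j]));
    }
    const scale = maxAbs > 0 ? maxAbs / 7 : 1;
    scales[g] = scale;

    // Two 4-bit codes per byte: low nibble = even index, high = odd index.
    for (let j = 0; j < group; j += 2) {
      const i = g * group + j;
      packed[i >> 1] = toNibble(row[i], scale) | (toNibble(row[i + 1], scale) << 4);
    }
  }
  return { packed, scales };
}
```

With this scheme a zero value maps to code 8, so an all-zero group packs to bytes of 0x88.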

Files (browser-facing)

| File | Size | Notes |
|---|---|---|
| onnx/vision_mlp.onnx (+.data) | ~1.73 GB | MoonViT + projector, fp32 (see note below) |
| onnx/language_tail_int4.onnx (+.data) | ~1.69 GB | Qwen2 language tower + tied lm_head, weight-only INT4 (block 128) |
| onnx/embed_tokens_int4_packed.bin | ~156 MB | INT4 embedding table, uint8 nibble-packed [152681, 1024] |
| onnx/embed_tokens_int4_scales.bin | ~19.5 MB | fp16 group scales [152681, 64] |
| onnx/embed_tokens_int4_meta.json | - | layout / dequant scheme |
| web_config.json | - | runtime wiring, token ids, tail size |

Total browser payload ≈ 3.6 GB. The big win here is the language side: the embedding table dropped from 1.25 GB fp32 → 176 MB INT4 and the language tail from 2.9 GB → 1.69 GB.
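
The embedding sizes quoted above follow directly from the tensor shapes. The short sketch below reproduces that arithmetic; the function name embeddingSizes is hypothetical, and the shapes come from the file table.

```javascript
// Sketch: byte counts for the embedding table in fp32 vs packed INT4.
function embeddingSizes(vocab = 152681, hidden = 2048, group = 32) {
  const fp32Bytes = vocab * hidden * 4;            // 4 bytes per element
  const packedBytes = vocab * (hidden / 2);        // two 4-bit codes per byte
  const scaleBytes = vocab * (hidden / group) * 2; // one fp16 scale per group
  const int4Bytes = packedBytes + scaleBytes;
  return { fp32Bytes, packedBytes, scaleBytes, int4Bytes, ratio: fp32Bytes / int4Bytes };
}

const s = embeddingSizes();
console.log((s.fp32Bytes / 1e9).toFixed(2) + " GB fp32");   // 1.25 GB fp32
console.log((s.packedBytes / 1e6).toFixed(0) + " MB packed"); // 156 MB packed
console.log((s.scaleBytes / 1e6).toFixed(1) + " MB scales");  // 19.5 MB scales
console.log(s.ratio.toFixed(2) + "x smaller");                // 7.11x smaller
```

Per element this is 0.5 bytes of codes plus 2/32 bytes of scale, i.e. 0.5625 bytes against 4 bytes in fp32.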

Vision precision note. The vision tower's linears export as ONNX Gemm / dynamic MatMul, which ORT's MatMulNBits INT4 quantizer cannot compress, so it ships fp32. fp16 conversion is blocked by the explicit .float() Cast islands in the ONNX-friendly MoonViT RoPE patch (post-hoc fp16 conversion produces type clashes; native fp16 export hits a torch/MPS expand_as+float64 limitation). A native mixed-precision vision re-export (Conv in fp32, rest fp16) is the planned follow-up to cut this to ~0.9 GB.

Embedding gather / dequant (JS reference)

row   = packed[token_id]                        // uint8[hidden/2]
low   = row & 0x0F ; high = (row >> 4)          // two nibbles per byte (low = even idx, high = odd idx)
q     = interleave(low, high)                   // uint4[hidden], values 0..15
emb[j] = (q[j] - 8) * scales[token_id][floor(j/32)] // fp32[hidden]; one scale per 32-wide group

The language graph then splices visual_features over the image-token positions and applies the SDLM block mask from input_ids.
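
A runnable version of the gather + dequant reference might look like the sketch below. It assumes both .bin files are flat, row-major buffers already loaded into typed arrays (packed as Uint8Array, scales as raw fp16 bits in a Uint16Array); the actual layout is described in embed_tokens_int4_meta.json and should be checked against it. The names halfToFloat and gatherEmbeddings are illustrative, not part of the shipped code.

```javascript
// Decode one IEEE 754 half-precision value from its 16 raw bits.
function halfToFloat(h) {
  const sign = (h & 0x8000) ? -1 : 1;
  const exp = (h >> 10) & 0x1f;
  const frac = h & 0x03ff;
  if (exp === 0) return sign * frac * 2 ** -24;          // zero / subnormal
  if (exp === 31) return frac ? NaN : sign * Infinity;   // NaN / infinity
  return sign * (1 + frac / 1024) * 2 ** (exp - 15);     // normal
}

// Sketch: build inputs_embeds (fp32, [tokens, hidden]) from the INT4 table.
// ASSUMPTION: row-major layout, low nibble = even index, high nibble = odd.
function gatherEmbeddings(tokenIds, packed, scalesF16, hidden = 2048, group = 32) {
  const bytesPerRow = hidden / 2;
  const groupsPerRow = hidden / group;
  const out = new Float32Array(tokenIds.length * hidden);

  for (let t = 0; t < tokenIds.length; t++) {
    const rowOff = tokenIds[t] * bytesPerRow;
    const scaleOff = tokenIds[t] * groupsPerRow;
    for (let b = 0; b < bytesPerRow; b++) {
      const byte = packed[rowOff + b];
      const j = 2 * b; // j and j + 1 share a group because group is even
      const scale = halfToFloat(scalesF16[scaleOff + Math.floor(j / group)]);
      out[t * hidden + j] = ((byte & 0x0f) - 8) * scale;
      out[t * hidden + j + 1] = ((byte >> 4) - 8) * scale;
    }
  }
  return out;
}
```

Decoding the scale once per group rather than once per byte would be faster; the per-byte form is kept here for readability.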

Validation

Validated against the fp32 PyTorch model on the sample image (slow / autoregressive mode):

  • Next-token argmax matches PyTorch exactly (token 151672 = <ref> start).
  • INT4 embedding gather error vs fp32 embeddings: max_abs 0.017, mean_rel ≈ 10% (per element), contributing only ~0.98 to the final logits (argmax-preserving).
  • The dominant INT4 weight error (~12.6 max logit delta) is unchanged from the baseline INT4 build.
  • fp16 vision vs fp32 vision: see validation_report.json.

Generation mode: use slow (autoregressive, greedy). The earlier prefill tail positions used by the fast/MTP path diverge under INT4 and are not relied upon here.

Intended use

Open-vocabulary detection / visual grounding: given an image and a category prompt (Locate all the instances that matches the following description: <category>.), the model emits <ref>label</ref><box>x1 y1 x2 y2</box> with coordinates normalized to 0–1000.
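
To turn the decoded text into pixel boxes, a small parser along these lines may be enough. It assumes integer coordinates separated by whitespace, one box per <ref> tag, and that x values scale with image width and y values with image height; the function name parseDetections is hypothetical.

```javascript
// Sketch: parse "<ref>label</ref><box>x1 y1 x2 y2</box>" spans and rescale
// the 0-1000 normalized coordinates to pixels.
// ASSUMPTION: integer coordinates and exactly one box per <ref> tag.
function parseDetections(text, imgWidth, imgHeight) {
  const pattern = /<ref>(.*?)<\/ref><box>\s*(\d+)\s+(\d+)\s+(\d+)\s+(\d+)\s*<\/box>/g;
  const detections = [];
  for (const m of text.matchAll(pattern)) {
    const [x1, y1, x2, y2] = m.slice(2, 6).map(Number);
    detections.push({
      label: m[1],
      box: [
        (x1 * imgWidth) / 1000,
        (y1 * imgHeight) / 1000,
        (x2 * imgWidth) / 1000,
        (y2 * imgHeight) / 1000,
      ],
    });
  }
  return detections;
}

// Example on a 640x480 image:
console.log(parseDetections("<ref>cat</ref><box>100 200 500 800</box>", 640, 480));
// [ { label: 'cat', box: [ 64, 96, 320, 384 ] } ]
```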

KV-cache graph for in-browser use

onnx/language_tail_kv_int4.onnx (+.data, ~1.65 GB) is the KV-cache version of the language tail used by the live demo. It takes inputs_embeds (+ input_ids for the plain-causal mask, position_ids, and 36×2 past_key/value GQA tensors [1,2,seq,128]) and returns logits for the last position plus present_key/value. Prefill passes length-0 past; decode passes the growing cache. This makes autoregressive decoding ~13× faster than the cache-less graph. Validated: prefill (empty past) → cached decode reproduces <ref>label</ref><box>…</box> detections; next-token argmax matches the fp32 torch model. See kv_validation_report.json.
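
The prefill-then-decode pattern described above can be sketched as a greedy loop. Everything session-specific is hidden behind a hypothetical runStep callback, which would wrap the onnxruntime-web session call and also supply position_ids and visual_features; the field names inputsEmbeds, inputIds, past, and present are placeholders rather than the graph's real input names, and embed stands in for the JS gather + dequant step.

```javascript
// Sketch: greedy autoregressive decoding with a KV cache.
// ASSUMPTION: runStep({ inputsEmbeds, inputIds, past }) resolves to
// { logits, present }, where logits covers the last position only and
// past === null stands for the length-0 cache used at prefill.
async function greedyDecode(runStep, embed, promptIds, eosId, maxNewTokens) {
  let past = null;         // prefill: empty cache
  let stepIds = promptIds; // prefill feeds the whole prompt at once
  const generated = [];

  for (let n = 0; n < maxNewTokens; n++) {
    const { logits, present } = await runStep({
      inputsEmbeds: embed(stepIds),
      inputIds: stepIds,
      past,
    });

    // Greedy choice: argmax over the last-position logits.
    let best = 0;
    for (let i = 1; i < logits.length; i++) {
      if (logits[i] > logits[best]) best = i;
    }
    if (best === eosId) break;

    generated.push(best);
    past = present;   // decode: reuse the growing cache
    stepIds = [best]; // and feed only the newest token
  }
  return generated;
}
```

After prefill, each step feeds a single token and reuses the cache instead of re-running the whole prefix, which is where the speedup over the cache-less graph comes from.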

Live in-browser demo (WebGPU): https://huggingface.co/spaces/Reza2kn/LocateAnything-3B-WebGPU

Source model & license: nvidia/LocateAnything-3B.
