Needle Transit — ONNX

ONNX export of rasaboun/needle-transit for in-browser / on-device inference via onnxruntime and onnxruntime-web. Same 26M-parameter transit extractor, runtime-portable.

Files

| File | Purpose |
|---|---|
| encoder.onnx | Encoder. Input input_ids (B, T) → output encoder_out (B, T, 512). Single pass. |
| decoder_step.onnx | One decoder step with explicit past-KV inputs and present-KV outputs. Run in a loop. |
| needle.model | SentencePiece BPE tokenizer (vocab 8192, byte_fallback=True, identity normalization). Loadable by sentencepiece-js / @huggingface/transformers. |
| tokenizer-specials.json | Special-token ids: pad=0, eos=1, bos=2, tool_call=4, tools=5. |
| needle_torch.config.json | Model dimensions: d_model=512, 8 query heads / 4 KV heads, 12 encoder layers, 8 decoder layers, max_seq_len=1024. |
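As a rough sizing aid, the config values above imply a per-head width and a decoder self-attention KV-cache footprint. This sketch assumes head_dim = d_model / query heads, a float32 cache, and one K and one V tensor per decoder layer. It ignores any cross-attention cache over encoder_out, whose handling is not documented here.

```python
# Sizing sketch from needle_torch.config.json.
# ASSUMPTIONS: head_dim = d_model / query heads, float32 cache,
# one K and one V tensor per decoder layer (self-attention only).
D_MODEL = 512
N_HEADS = 8          # query heads
N_KV_HEADS = 4       # key/value heads (grouped-query attention)
N_DEC_LAYERS = 8
MAX_SEQ_LEN = 1024
BYTES_PER_VALUE = 4  # float32 (assumed)


def head_dim() -> int:
    """Per-head width, assuming d_model splits evenly across query heads."""
    assert D_MODEL % N_HEADS == 0
    return D_MODEL // N_HEADS


def self_attn_kv_bytes(seq_len: int, batch: int = 1) -> int:
    """Bytes held by the decoder self-attention KV cache after seq_len tokens."""
    values_per_token_per_layer = 2 * N_KV_HEADS * head_dim()  # K and V
    return batch * seq_len * N_DEC_LAYERS * values_per_token_per_layer * BYTES_PER_VALUE


if __name__ == "__main__":
    print(head_dim())                               # 64
    print(N_HEADS // N_KV_HEADS)                    # 2 query heads per KV head
    print(self_attn_kv_bytes(1))                    # 16384
    print(self_attn_kv_bytes(MAX_SEQ_LEN) / 2**20)  # 16.0 (MiB)
```

Under these assumptions, a full 1024-token decode holds about 16 MiB of self-attention cache per sequence.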

Inference flow

  1. Tokenize query + tools with needle.model, framing with the special tokens above.
  2. Run encoder.onnx once → encoder_out.
  3. Loop decoder_step.onnx greedily from bos, passing the KV cache (present-KV → past-KV) and feeding back each emitted token, until eos.
  4. Detokenize the emitted tokens into the JSON list of tool calls (an empty list, [], means refusal).
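Steps 2 and 3 amount to plain greedy decoding with a cache hand-off. In this sketch the step function is a hypothetical stand-in: in real use it would wrap dec.run(...) on decoder_step.onnx, close over encoder_out, and return (logits, present-KV). Only the bos/eos ids come from tokenizer-specials.json; the rest is illustrative.

```python
from typing import Any, Callable, List, Sequence, Tuple

BOS_ID, EOS_ID = 2, 1  # from tokenizer-specials.json

# Hypothetical step signature: (token_id, past_kv) -> (logits, present_kv).
StepFn = Callable[[int, Any], Tuple[Sequence[float], Any]]


def greedy_decode(step: StepFn, max_new_tokens: int = 256) -> List[int]:
    """Greedy decoding from bos: feed each argmax token back in, stop at eos."""
    token, past_kv, out = BOS_ID, None, []
    for _ in range(max_new_tokens):
        logits, past_kv = step(token, past_kv)  # present-KV becomes the next past-KV
        token = max(range(len(logits)), key=lambda i: logits[i])  # argmax
        if token == EOS_ID:
            break
        out.append(token)
    return out


if __name__ == "__main__":
    script = [7, 9, EOS_ID]  # toy "model" that always emits 7, 9, then eos

    def toy_step(token: int, past: Any) -> Tuple[List[float], Any]:
        past = (past or []) + [token]  # the "cache" grows by one entry per step
        logits = [0.0] * 16
        logits[script[len(past) - 1]] = 1.0
        return logits, past

    print(greedy_decode(toy_step))  # [7, 9]
```

The returned ids exclude bos and eos, so they could go straight to sp.decode(...) for step 4.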

The decoder is exported as a single step with past/present KV as graph inputs and outputs. The host drives the loop, which enables token streaming and keeps ONNX control-flow operators (Loop/If) out of the graph.
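The host-side KV hand-off can be sketched with numpy. The input/output names (past.<i>.key, present.<i>.key, ...) and the (batch, kv_heads, seq, head_dim) layout are assumptions, not read from the exported graph; check the real names with [i.name for i in dec.get_inputs()] and [o.name for o in dec.get_outputs()] first. The first step gets a zero-length sequence axis, a common convention for single-step ONNX decoders, assuming this graph accepts it.

```python
import numpy as np

# HYPOTHETICAL naming and layout: "past.<i>.key"/"past.<i>.value" inputs,
# "present.<i>.key"/"present.<i>.value" outputs, shape (batch, kv_heads, seq, head_dim).


def empty_past(n_layers: int = 8, batch: int = 1, kv_heads: int = 4,
               head_dim: int = 64) -> dict:
    """Zero-length past-KV feeds for the first decoder step."""
    shape = (batch, kv_heads, 0, head_dim)
    feeds = {}
    for i in range(n_layers):
        feeds[f"past.{i}.key"] = np.zeros(shape, dtype=np.float32)
        feeds[f"past.{i}.value"] = np.zeros(shape, dtype=np.float32)
    return feeds


def present_to_past(output_names, outputs) -> dict:
    """Rename each present-KV output to the past-KV input the next step expects."""
    return {
        name.replace("present.", "past.", 1): value
        for name, value in zip(output_names, outputs)
        if name.startswith("present.")
    }
```

Since dec.run(None, feeds) returns outputs in graph order, the next step's KV feeds would be present_to_past([o.name for o in dec.get_outputs()], outputs), merged with the new token id and encoder_out.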

Refer to the JAX inference path in the official cactus-compute/needle package for the exact tokenization framing and decode loop.

Usage (Python / onnxruntime)

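```python
import onnxruntime as ort
import sentencepiece as spm

enc = ort.InferenceSession("encoder.onnx")
dec = ort.InferenceSession("decoder_step.onnx")
sp = spm.SentencePieceProcessor(model_file="needle.model")

# List the decoder's past-KV inputs and present-KV outputs before wiring the loop.
print([i.name for i in dec.get_inputs()])
print([o.name for o in dec.get_outputs()])

# ... tokenize query+tools, run enc once, loop dec until eos ...
```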

For the browser, swap onnxruntime for onnxruntime-web (WASM/WebGPU backend); the file layout is identical.

Examples

Example queries and the tool calls they produce. On each of these, the ONNX output is verified token-identical to the JAX checkpoint:

| Query | Output |
|---|---|
| Itinéraire de Bastille à Nation | [{"name":"search_itinerary","arguments":{"origin":"Bastille","destination":"Nation"}}] |
| De Issy à Charles-de-Gaulle, départ 14h | [{"name":"search_itinerary","arguments":{"origin":"Issy","destination":"Charles-de-Gaulle","time_human":"départ 14h","time_mode":"depart_at"}}] |
| How do I get from Gare du Nord to La Défense? | [{"name":"search_itinerary","arguments":{"origin":"Gare du Nord","destination":"La Défense"}}] |
| Prochain métro à Bastille ligne 1 ? | [{"name":"get_next_arrivals","arguments":{"station":"Bastille","line":"1"}}] |
| prochains passages à Châtelet | [{"name":"get_next_arrivals","arguments":{"station":"Châtelet"}}] |
| cmt aller a chatelet depuis nation | [{"name":"search_itinerary","arguments":{"origin":"nation","destination":"chatelet"}}] |
| Quel temps fait-il ? | [] |
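Assuming the detokenized output is exactly the JSON text in the Output column, a host can turn it into calls with the standard json module. The validation rules here (a list of objects with name and arguments, empty meaning refusal) are inferred from the examples above, not from a published schema.

```python
import json
from typing import Any, Dict, List


def parse_tool_calls(text: str) -> List[Dict[str, Any]]:
    """Parse detokenized model output into a list of {"name", "arguments"} calls.

    An empty string or "[]" is treated as a refusal (no tool call).
    """
    text = text.strip()
    if not text:
        return []
    calls = json.loads(text)
    if not isinstance(calls, list):
        raise ValueError(f"expected a JSON list, got {type(calls).__name__}")
    for call in calls:
        if not isinstance(call, dict) or "name" not in call or "arguments" not in call:
            raise ValueError(f"malformed tool call: {call!r}")
    return calls


if __name__ == "__main__":
    out = '[{"name":"get_next_arrivals","arguments":{"station":"Bastille","line":"1"}}]'
    for call in parse_tool_calls(out):
        print(call["name"], call["arguments"])
    print(parse_tool_calls("[]"))  # []
```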

Results

On the examples above, the ONNX export is functionally equivalent to the JAX checkpoint: it produces token-identical tool calls (verified end-to-end). Dataset and evaluation: coming soon.
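Token-identical is a stricter bar than producing the same JSON. A minimal way to check it, given the JAX and ONNX token-id sequences for one query, is to report the first index where they diverge. This is an illustrative helper, not the logic inside verify_parity.py.

```python
from typing import Optional, Sequence


def first_divergence(ref: Sequence[int], test: Sequence[int]) -> Optional[int]:
    """Index of the first differing token, or None if the sequences are identical."""
    for i, (a, b) in enumerate(zip(ref, test)):
        if a != b:
            return i
    if len(ref) != len(test):
        return min(len(ref), len(test))  # one sequence is a strict prefix of the other
    return None
```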

Finetuning

Finetune on your own tools with the customized scripts in github.com/Rasaboun/needle-transit (tunable LR/Muon-LR, per-field loss weighting, metrics logging), then re-export to ONNX with the toolkit below.
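Per-field loss weighting is only named here, not specified. As a generic illustration (not the repo's implementation), a weighted token-level cross-entropy scales each target token by the weight of the field it belongs to, e.g. weighting destination tokens above time_human tokens. The field_ids tagging and the weight values are hypothetical.

```python
import numpy as np


def weighted_token_ce(logits, targets, field_ids, field_weights) -> float:
    """Weighted mean token cross-entropy.

    logits: (T, V) scores; targets: (T,) token ids;
    field_ids: (T,) index of the (hypothetical) field each target belongs to;
    field_weights: one weight per field.
    """
    logits = np.asarray(logits, dtype=np.float64)
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    nll = -log_probs[np.arange(len(targets)), targets]    # per-token negative log-likelihood
    w = np.asarray(field_weights, dtype=np.float64)[field_ids]
    return float((w * nll).sum() / w.sum())
```

Normalizing by the summed weights keeps the loss scale comparable when field weights change.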

Provenance & reproduction

Upstream Needle is JAX/Flax, so torch.onnx.export can't run against it directly. These artifacts were produced via a port-and-copy pipeline: reimplement the Simple Attention Network in PyTorch, copy weights tensor-by-tensor from the Flax checkpoint, verify Flax↔PyTorch↔ONNX parity, then export encoder + decoder-step.
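Copying weights is mostly renaming plus a layout change: Flax nn.Dense stores its kernel as (in_features, out_features), while PyTorch nn.Linear stores weight as (out_features, in_features). This sketch flattens a nested Flax param tree and applies those generic rules (kernel transposed to weight; LayerNorm scale and nn.Embed embedding renamed to weight). The actual Needle name mapping lives in convert_weights.py and is not reproduced here.

```python
from collections.abc import Mapping

import numpy as np


def flatten(tree: Mapping, prefix: str = "") -> dict:
    """Flatten a nested Flax param tree into {"a.b.c": array}."""
    flat = {}
    for key, value in tree.items():
        name = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, Mapping):
            flat.update(flatten(value, name))
        else:
            flat[name] = np.asarray(value)
    return flat


def flax_to_torch(params: Mapping) -> dict:
    """Apply generic Flax -> PyTorch leaf rules (real Needle mapping not reproduced)."""
    state = {}
    for name, arr in flatten(params).items():
        stem, _, leaf = name.rpartition(".")
        prefix = f"{stem}." if stem else ""
        if leaf == "kernel" and arr.ndim == 2:
            state[prefix + "weight"] = arr.T.copy()  # nn.Dense (in, out) -> nn.Linear (out, in)
        elif leaf in ("scale", "embedding"):
            state[prefix + "weight"] = arr           # LayerNorm gain / nn.Embed table
        else:
            state[name] = arr                        # e.g. bias keeps its name and shape
    return state
```

The arrays could then be wrapped with torch.from_numpy(...) and loaded via load_state_dict into a port whose module names match.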

The conversion scripts live in onnx-community/needle-onnx (convert_weights.py, export_onnx.py, verify_port_parity.py, verify_parity.py, the PyTorch port, and PORTING.md). They are parameterized by the source checkpoint, so any Needle finetune can be re-exported with the same recipe:

```bash
# get the conversion toolkit
hf download onnx-community/needle-onnx --local-dir needle-onnx && cd needle-onnx

# 1. Flax checkpoint → PyTorch state_dict
uv run python convert_weights.py --ckpt-repo rasaboun/needle-transit --ckpt-file needle-transit.pkl
# 2. verify the port matches upstream (< 1e-3)
uv run python verify_port_parity.py
# 3. export encoder + decoder-step to ONNX
uv run python export_onnx.py
# 4. verify ONNX ↔ PyTorch ↔ native generate()
uv run python verify_parity.py --ckpt-repo rasaboun/needle-transit --ckpt-file needle-transit.pkl
```
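Steps 2 and 4 compare numeric outputs across implementations. A minimal version of such a check assumes the < 1e-3 bound is a maximum absolute difference on logits (the metric is not stated above). It also confirms that the greedy argmax agrees at every position, since that is what keeps decoded tokens identical.

```python
import numpy as np


def check_parity(ref, test, tol: float = 1e-3):
    """Compare two logit arrays of shape (..., vocab).

    Returns (max_abs_diff, argmax_agree, passed). The max-abs-diff metric is an assumption.
    """
    ref = np.asarray(ref, dtype=np.float64)
    test = np.asarray(test, dtype=np.float64)
    if ref.shape != test.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {test.shape}")
    max_diff = float(np.max(np.abs(ref - test)))
    argmax_agree = bool(np.array_equal(ref.argmax(axis=-1), test.argmax(axis=-1)))
    return max_diff, argmax_agree, (max_diff < tol and argmax_agree)
```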

License & attribution

MIT. Exported from rasaboun/needle-transit, itself fine-tuned from Cactus-Compute/needle (© Cactus Compute, MIT).
