SAM 2.1 (Hiera-Tiny) mask decoder: LiteRT GPU

On-device LiteRT / TFLite conversion of the prompt-conditioned mask decoder of SAM 2.1 Hiera-Tiny (Meta, Apache-2.0). The graph delegates fully to the mobile GPU via the LiteRT CompiledModel API (ML Drift / LITERT_CL delegate), but on the Pixel 8a it currently has to run on CPU to produce correct masks (see below).

This is the lightweight, per-click half of the SAM 2 image path. Pair it with the SAM 2.1 Hiera-Tiny image encoder (run once per image, ~7 ms): the encoder produces the multi-scale feature pyramid, and this decoder turns each tap's point prompt into segmentation masks in a few ms, giving interactive "tap to segment".

Task              Mask decoder for promptable segmentation (SAM 2 image path)
Architecture      2-layer two-way transformer (token↔image cross-attention) + mask up-sampler
Inputs            image_embeddings [1,256,64,64], sparse_prompt [1,2,256], feat_s1 [1,64,128,128], feat_s0 [1,32,256,256]
Outputs           pred_masks [1,3,256,256] (logits, 3 multimask candidates), iou_scores [1,3]
Precision / size  FP16, 17 MB
Device            Pixel 8a: GPU-resident (358/358 nodes on LITERT_CL), but run on CPU for correct masks (see below)
Op set            banned ops = NONE, >4-D tensors = 0 (BATCH_MATMUL ×15, SOFTMAX ×7, GELU ×2, CONV_2D ×2)

⚠ Residency ≠ correctness. This decoder fully delegates to the LiteRT GPU (358/358 LITERT_CL nodes), but on the Pixel 8a its GPU fp16 output is numerically wrong: a face tap that the CPU decoder segments at IoU ≈ 0.62 collapses to ≈ 0.10 under the GPU delegate, with the mask landing on the background. The companion encoder's GPU output is fine (encoder on GPU + decoder on CPU matches all-CPU exactly). The cause is not LayerNorm (plain and overflow-safe variants give the same wrong GPU result); the offending op is still being localized with a per-op GPU dump. Run this decoder on CPU (it is tiny and fast); the heavy image encoder is the part that benefits from the GPU.

Pipeline (how the inputs are produced)

RGB image ──> image encoder (run once) ──> image_embeddings[1,256,64,64], feat_s1[1,64,128,128], feat_s0[1,32,256,256]
tap (x,y in 1024-space) ──> prompt encode (host-side, see below) ──> sparse_prompt[1,2,256]
                                              │
       image_embeddings + feat_s0/s1 + sparse_prompt ──> THIS decoder ──> 3 masks + 3 IoU
       pick argmax(IoU) ──> upsample 256×256 logits to image size ──> threshold > 0 ──> overlay

The decoder expects the encoder variant that already folds in conv_s0 / conv_s1 and no_memory, so the encoder's outputs are decoder-ready as-is (no host reshaping between the two models).
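
The last line of the diagram (pick the best candidate, upsample, threshold) can be sketched in plain Python as below. It assumes bilinear interpolation with half-pixel centres (align_corners=False), which is the mode upstream SAM 2 uses when resizing low-resolution mask logits to the original image size; this card itself only says "upsample". The helper names are hypothetical, and masks are nested lists only to keep the sketch dependency-free; on device a native or GPU resize would replace the per-pixel loop.

```python
import math


def pick_best_mask(pred_masks, iou_scores):
    """Return (index, mask) of the candidate with the highest predicted IoU.

    pred_masks: list of candidate masks, each a list of rows of float logits.
    iou_scores: list of floats, one per candidate (the decoder's iou_scores).
    """
    best = max(range(len(iou_scores)), key=lambda i: iou_scores[i])
    return best, pred_masks[best]


def bilinear_resize(logits, out_h, out_w):
    """Bilinearly resize a 2-D list of logits using half-pixel sample centres
    (the align_corners=False convention) with edge clamping."""
    in_h, in_w = len(logits), len(logits[0])

    def src_coord(o, in_size, out_size):
        # Map an output pixel centre into input coordinates, clamped to range.
        s = (o + 0.5) * in_size / out_size - 0.5
        s = min(max(s, 0.0), in_size - 1.0)
        i0 = int(math.floor(s))
        i1 = min(i0 + 1, in_size - 1)
        return i0, i1, s - i0

    out = []
    for oy in range(out_h):
        y0, y1, wy = src_coord(oy, in_h, out_h)
        row = []
        for ox in range(out_w):
            x0, x1, wx = src_coord(ox, in_w, out_w)
            top = logits[y0][x0] * (1.0 - wx) + logits[y0][x1] * wx
            bottom = logits[y1][x0] * (1.0 - wx) + logits[y1][x1] * wx
            row.append(top * (1.0 - wy) + bottom * wy)
        out.append(row)
    return out


def binarize(logits, threshold=0.0):
    """Logit > 0 is the same as sigmoid probability > 0.5."""
    return [[v > threshold for v in row] for row in logits]


if __name__ == "__main__":
    # Toy 2x2 logits standing in for the three 256x256 decoder outputs.
    pred_masks = [
        [[-1.0, 1.0], [-1.0, 1.0]],
        [[-5.0, -5.0], [-5.0, -5.0]],
        [[2.0, 2.0], [2.0, 2.0]],
    ]
    idx, best = pick_best_mask(pred_masks, [0.936, 0.022, 0.399])
    print(idx, binarize(bilinear_resize(best, 2, 4)))
```

Resizing the logits before thresholding, rather than resizing a binary mask, keeps mask edges smooth at the output resolution.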

Host-side prompt encoding (single positive point)

The tiny point→token step (a sin/cos positional encoding) is done on the host to keep the GPU graph sin/cos-free. For a positive click (x, y) in 1024×1024 model space, with the bundled constants posmat [2,128], point_embed[1] [256], not_a_point [256]:

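import numpy as np
# x, y: tap in 1024x1024 model space; posmat [2,128], point_embed[1] [256], not_a_point [256]: bundled constants
c      = (np.array([x, y], dtype=np.float32) + 0.5) / 1024   # normalize, half-pixel shift
c      = 2 * c - 1                                            # map to [-1, 1]
coord  = 2 * np.pi * (c @ posmat)                             # [128]
token0 = np.concatenate([np.sin(coord), np.cos(coord)]) + point_embed[1]   # the positive point
token1 = not_a_point                                          # the padding point
sparse_prompt = np.stack([token0, token1])[None].astype(np.float32)        # [1, 2, 256]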

This matches the upstream Sam2PromptEncoder to ~3.7e-7.
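
The formula above takes (x, y) already in 1024×1024 model space. For a tap on the displayed photo, a minimal sketch of that conversion follows. It assumes the encoder's preprocessing stretches the whole image to 1024×1024 without padding or aspect-ratio preservation (as upstream SAM 2's own transforms do); this card does not state the encoder's preprocessing, so check the encoder card. `tap_to_model_space` is a hypothetical helper, not part of the bundle.

```python
def tap_to_model_space(x, y, image_w, image_h, model_size=1024):
    """Map a tap from original-image pixels to model_size x model_size space.

    ASSUMPTION: the encoder input is the whole image stretched (no padding,
    no aspect-ratio preservation) to model_size x model_size. A letterboxed
    or cropped preprocessing path needs a different mapping.
    """
    if image_w <= 0 or image_h <= 0:
        raise ValueError("image dimensions must be positive")
    # Scale each axis independently, matching the independent x/y resize.
    # The +0.5 half-pixel shift is left to the prompt-encoding formula.
    return x * model_size / image_w, y * model_size / image_h


if __name__ == "__main__":
    # A centre tap on a 1920x1080 photo lands at the centre of model space.
    print(tap_to_model_space(960, 540, 1920, 1080))  # (512.0, 512.0)
```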

GPU-clean conversion (what was re-authored)

Converted with litert-torch using model-side rewrites only (no converter patches); each rewrite is weights-faithful:

  1. Two-way attention (×7): re-expressed as 3-D batched SDPA [heads, N, d] (a 4-D SDPA makes the delegate emit a BROADCAST_TO).
  2. Mask up-sampler ConvTranspose2d (×2): replaced with the exact zero-stuff + Conv2d identity (TRANSPOSE_CONV is rejected on the Pixel 8a; this is numerically identical, not a bilinear approximation).
  3. Mask head: the hyper_in @ upscaled-mask projection is kept ≤4-D (the upstream 5-D [1,1,4,256,256] tensor is collapsed, since batch and point-batch are both 1).
  4. LayerNorm (×9): scale-before-square SafeLayerNorm (fp16-overflow-safe, mathematically identical).
  5. Constants baked: image_positional_embeddings and the no-mask dense prompt are baked as buffers.
  6. Multimask path: static slice [1:] over the 4 mask-token outputs, keeping the 3 multimask candidates; no dynamic-stability argmax / gather / where.
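
Two of these rewrites are easy to check numerically in isolation: the zero-stuff identity behind item 2 and scale-before-square LayerNorm (item 4). The sketch below shows both in 1-D plain Python under stated simplifications: no padding or groups in the transposed convolution, and no LayerNorm affine weights. For the 2-D case the same zero-stuffing is applied along both spatial axes, and since PyTorch stores ConvTranspose2d weights as [in, out, kH, kW], the equivalent Conv2d weight swaps the first two axes and flips both spatial axes (upstream SAM 2 uses kernel 2, stride 2 here). The LayerNorm function is one way to write scale-before-square, not necessarily the author's SafeLayerNorm; Python floats are float64, so it shows the algebra and the bounded intermediates rather than fp16 behaviour itself.

```python
import math


def conv_transpose1d(x, w, stride):
    """Reference transposed convolution (no padding): each input sample
    scatters a scaled copy of the kernel into the output."""
    k = len(w)
    out = [0.0] * ((len(x) - 1) * stride + k)
    for i, xi in enumerate(x):
        for j, wj in enumerate(w):
            out[i * stride + j] += xi * wj
    return out


def zero_stuff_conv1d(x, w, stride):
    """The same operation as zero-stuffing + padding + an ordinary 'valid'
    convolution (cross-correlation) with the flipped kernel."""
    k = len(w)
    stuffed = []
    for i, xi in enumerate(x):
        stuffed.append(xi)
        if i < len(x) - 1:
            stuffed.extend([0.0] * (stride - 1))  # stride-1 zeros between samples
    padded = [0.0] * (k - 1) + stuffed + [0.0] * (k - 1)
    flipped = w[::-1]
    return [sum(padded[o + j] * flipped[j] for j in range(k))
            for o in range(len(padded) - k + 1)]


def plain_layer_norm(x, eps=1e-6):
    """Textbook LayerNorm (no affine): squares the raw deviations."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def scaled_layer_norm(x, eps=1e-6):
    """Scale-before-square LayerNorm (no affine): deviations are divided by
    their largest magnitude s before squaring, so every squared term is <= 1.
    Algebraically d / sqrt(var + eps) == u / sqrt(var_u + eps / s**2)."""
    mean = sum(x) / len(x)
    d = [v - mean for v in x]
    s = max(abs(v) for v in d)
    if s == 0.0:  # constant input: plain_layer_norm also returns zeros
        return [0.0] * len(x)
    u = [v / s for v in d]
    var_u = sum(v * v for v in u) / len(u)
    return [v / math.sqrt(var_u + eps / (s * s)) for v in u]


if __name__ == "__main__":
    x, w = [1.0, 2.0, -3.0], [0.5, -1.0]  # kernel 2, stride 2
    print(conv_transpose1d(x, w, 2), zero_stuff_conv1d(x, w, 2))
    # A deviation above ~256 would overflow fp16 (max 65504) when squared.
    v = [400.0, -150.0, 45.0, 7.5]
    print(plain_layer_norm(v), scaled_layer_norm(v))
```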

Fidelity (honest)

In eager PyTorch, the re-authored decoder is numerically exact against the original (cosine = 1.000). End-to-end through the two FP16 .tflite models (encoder → host prompt encoding → decoder) versus the PyTorch reference, for a center click:

Metric                          Value
mask logits cosine              0.999999
binary mask IoU (threshold 0)   0.99964
IoU-score head (ref vs got)     [0.936, 0.022, 0.399] vs [0.936, 0.022, 0.399]
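
The two mask-level metrics in the table follow standard definitions: cosine similarity over the flattened logits, and IoU of the two masks after thresholding the logits at 0. A minimal sketch of those definitions is below; it is not the script that produced the numbers above, and treating two empty masks as IoU 1.0 is a convention chosen here.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two flattened logit arrays of equal length."""
    if len(a) != len(b):
        raise ValueError("arrays must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def binary_mask_iou(a, b, threshold=0.0):
    """IoU of the binary masks obtained by thresholding two logit arrays."""
    if len(a) != len(b):
        raise ValueError("arrays must have the same length")
    mask_a = [x > threshold for x in a]
    mask_b = [y > threshold for y in b]
    inter = sum(p and q for p, q in zip(mask_a, mask_b))
    union = sum(p or q for p, q in zip(mask_a, mask_b))
    # Convention chosen here: two empty masks agree perfectly.
    return inter / union if union else 1.0


if __name__ == "__main__":
    ref = [2.0, -1.0, 0.5, -3.0]  # toy stand-ins for flattened 3x256x256 logits
    got = [2.1, -0.9, 0.4, -3.2]
    print(cosine_similarity(ref, got), binary_mask_iou(ref, got))
```

Cosine on logits is sensitive to drift anywhere in the map, while the thresholded IoU only reacts where the sign of a logit flips, which is why the two can differ for the same pair of outputs.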

The deepest (64×64) image embedding drifts slightly on the GPU, because its deep attention runs in true fp16 (see the encoder card). Mask boundaries are carried by the near-exact high-resolution features (feat_s0 / feat_s1), so mask quality holds.

Usage (Android / LiteRT CompiledModel)

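// once per image: encoder on GPU
val enc = CompiledModel.create(assets, "sam2_image_encoder_v2_fp16.tflite", Options(Accelerator.GPU), null)
// per tap: decoder on CPU (GPU-resident but fp16-incorrect on device; see the residency note above)
val dec = CompiledModel.create(assets, "sam2_mask_decoder_fp16.tflite", Options(Accelerator.CPU), null)
// dec inputs (by index): 0 image_embeddings[1,256,64,64], 1 sparse[1,2,256], 2 feat_s1[1,64,128,128], 3 feat_s0[1,32,256,256]
// dec outputs: pred_masks[1,3,256,256] logits, iou_scores[1,3] -> pick argmax(iou), upsample, threshold 0
val decIn = dec.createInputBuffers()
val decOut = dec.createOutputBuffers()
decIn[1].writeFloat(sparsePrompt)  // sparsePrompt: hypothetical FloatArray(512) from the host prompt encoder; fill 0, 2, 3 from the encoder outputs
dec.run(decIn, decOut)
val masks = decOut[0].readFloat()  // 3 * 256 * 256 logits, assuming outputs come in the order listed above
val iou = decOut[1].readFloat()    // 3 IoU scores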

Training data & PII

SAM 2 was trained by Meta on SA-1B (licensed photos) and SA-V (licensed videos) with model-in-the-loop mask annotation. No new training was performed for this conversion; it is a weights-faithful format change of the public facebook/sam2.1-hiera-tiny checkpoint. Because the source data is real-world imagery it may incidentally contain people, faces, vehicles, signage and other PII; no PII was deliberately collected and this conversion adds none. Apply your own content/PII filtering as appropriate. See the SAM 2 release and paper for full dataset details.

License

Apache-2.0, inherited from the upstream SAM 2.1. This is a format conversion; all credit to the original authors (Meta AI).

Model repository: litert-community/SAM2.1-Hiera-Tiny-Mask-Decoder