umt5-xxl-mlx

MLX-converted UMT5-XXL encoder weights for Apple Silicon. Converted from google/umt5-xxl (Apache-2.0).

Runtime code: github.com/sb1992/mlx-umt5

What's in this repo

Two precast variant directories; pick one:

Folder   Precision                    Disk      Load peak (memory)   Note
int8/    per-channel-symmetric int8   ~6.3 GB   7.02 GB              ⭐ recommended
bf16/    bf16                         ~11 GB    11.30 GB             full-precision reference

Tokenizer files (tokenizer.json, tokenizer_config.json, special_tokens_map.json, spiece.model) are at the repo root, shared by both variants.

Each variant directory contains a single precast safetensors file (model.safetensors) plus a config.json. The runtime detects the precision from the file's metadata header, so no precision flag is needed.
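As a rough illustration of precision detection from the header, the sketch below reads only the safetensors header (an 8-byte little-endian length followed by a JSON object) and infers the variant. It is not the mlx_umt5 implementation. The "precision" metadata key and the dtype-based fallback are assumptions, and parse_header, read_header, and detect_precision are hypothetical helper names.

```python
import json
import struct


def parse_header(prefix: bytes) -> dict:
    """Parse the JSON header from the leading bytes of a .safetensors file."""
    # Bytes 0..8 hold a little-endian u64: the length N of the JSON header.
    (n,) = struct.unpack("<Q", prefix[:8])
    return json.loads(prefix[8 : 8 + n])


def read_header(path: str) -> dict:
    """Read only the header from disk; tensor data is never touched."""
    with open(path, "rb") as f:
        (n,) = struct.unpack("<Q", f.read(8))
        return json.loads(f.read(n))


def detect_precision(header: dict) -> str:
    """Hypothetical helper: infer the variant from metadata, else from dtypes."""
    meta = header.get("__metadata__", {})
    # ASSUMPTION: a "precision" metadata key; the real runtime may use another.
    if "precision" in meta:
        return meta["precision"]
    dtypes = {v["dtype"] for k, v in header.items() if k != "__metadata__"}
    if "I8" in dtypes:   # int8 weights (scales may be stored in another dtype)
        return "int8"
    if "BF16" in dtypes:
        return "bf16"
    return "unknown"
```

For example, detect_precision(read_header("int8/model.safetensors")) would report int8 if the file stores I8 weights and carries no overriding metadata key. Because only the header is read, the check is cheap even for a multi-gigabyte file.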

Download only what you need (~6.3 GB for int8, not ~18 GB). from_pretrained(...) fetches only the chosen variant plus the tokenizer files. If you download manually, use --include (see the hf download command below); a bare git clone or an unfiltered snapshot_download pulls both variants (~18 GB total).
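The variant filtering can be sketched with shell-style glob patterns. This assumes the repo layout described above (int8/ and bf16/ directories, tokenizer files at the root). variant_patterns and select_files are hypothetical helpers that preview which repo paths a pattern list would select, using Python's fnmatch for matching.

```python
from fnmatch import fnmatch

# Shared tokenizer files at the repo root (names taken from this card).
TOKENIZER_FILES = [
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "spiece.model",
]


def variant_patterns(variant):
    """Glob patterns for one variant directory plus the shared tokenizer files."""
    if variant not in ("int8", "bf16"):
        raise ValueError(f"unknown variant: {variant!r}")
    return [f"{variant}/*", *TOKENIZER_FILES]


def select_files(repo_files, patterns):
    """Keep only repo paths that match at least one glob pattern."""
    return [p for p in repo_files if any(fnmatch(p, pat) for pat in patterns)]
```

The same list can be passed as allow_patterns to huggingface_hub's snapshot_download, e.g. snapshot_download(repo_id="shraey/umt5-xxl-mlx", allow_patterns=variant_patterns("int8"), local_dir="./umt5-xxl-mlx"), so that only one variant is fetched.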

Usage

from mlx_umt5 import from_pretrained, encode

loaded = from_pretrained("shraey/umt5-xxl-mlx", variant="int8")
embeds, masks = encode(loaded, ["A cinematic shot of a mountain at sunrise."])
# embeds[0]: (1, 512, 4096) fp32
# masks[0]:  (1, 512)       int32

To download manually, fetch only the variant you need plus the shared tokenizer files:

hf download shraey/umt5-xxl-mlx \
    --include "int8/*" tokenizer.json tokenizer_config.json \
    special_tokens_map.json spiece.model \
    --local-dir ./umt5-xxl-mlx
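The shapes in the usage comments suggest each prompt is padded to 512 positions. The following is a minimal NumPy sketch for dropping that padding. It assumes the mask uses 1 for real tokens and 0 for padding (the usual Hugging Face convention, not confirmed here for mlx_umt5) and that the returned MLX arrays convert with np.array(...). trim_padding is a hypothetical helper.

```python
import numpy as np


def trim_padding(embeds, mask):
    """Keep only the real-token positions of one prompt's encoder output.

    ASSUMPTION: embeds is (1, L, D), mask is (1, L), and mask == 1 marks
    real tokens while 0 marks padding.
    """
    emb = np.array(embeds)                 # e.g. (1, 512, 4096) float32
    keep = np.array(mask)[0].astype(bool)  # (L,) boolean selector
    return emb[0][keep]                    # (num_real_tokens, D)
```

For instance, trim_padding(embeds[0], masks[0]) would return an array of shape (num_tokens, 4096) for the first prompt.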

Parity

  • int8/ precast is byte-identical to the in-process per-channel-symmetric int8 path (max-abs-diff 0).
  • bf16/ precast is bit-identical to loading the HF fp32 shards and casting to bf16 (cosine 1.0).
  • int8 vs bf16 quality cosine: short prompt 0.999631, 512-token prompt 0.999059, unicode prompt 0.998631; the worst case (0.99863) is above the 0.995 gate.
  • MLX-bf16 vs torch-fp32 oracle: cosine 0.999924 / 0.999575 / 0.999900 across three prompt lengths.
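The two parity metrics above can be illustrated with a short NumPy sketch. It assumes cosine is computed over the whole flattened embedding; the actual parity scripts may compare per token or only over unmasked positions. cosine, max_abs_diff, and passes_gate are hypothetical names, and the 0.995 default mirrors the gate quoted above.

```python
import numpy as np


def cosine(a, b):
    """Cosine similarity of two tensors, each flattened to one vector."""
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


def max_abs_diff(a, b):
    """Largest elementwise absolute difference; 0.0 means every element matches."""
    diff = np.asarray(a, np.float64) - np.asarray(b, np.float64)
    return float(np.max(np.abs(diff)))


def passes_gate(reference, candidate, gate=0.995):
    """Hypothetical check: candidate output must stay within the cosine gate."""
    return cosine(reference, candidate) >= gate
```

Cosine is insensitive to a uniform rescaling of the output, which is why the card pairs it with max-abs-diff when claiming byte-identical precasts.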

Quantization note

Per-channel-symmetric (pcs) int8 is the only safe quantization scheme for UMT5-XXL. MLX group-affine quantization corrupts this model (cosine ~0.108). UMT5 has no attention output scaling, and its residual stream grows monotonically across the 24 layers, so group-affine error compounds with depth. The int8/ variant uses pcs quantization only and is byte-identical to the in-process conversion.
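Per-channel-symmetric int8 can be sketched as follows: each output row gets its own scale, chosen so that the row's largest absolute weight maps to 127, and there is no zero-point. This assumes weights laid out as (out_features, in_features), as in mlx.nn.Linear, and a symmetric [-127, 127] range. It illustrates the scheme and does not reproduce the exact conversion used to build int8/.

```python
import numpy as np


def quantize_pcs_int8(w):
    """Per-channel-symmetric int8: one float scale per output row, zero-point 0."""
    w = np.asarray(w, dtype=np.float32)
    absmax = np.max(np.abs(w), axis=1, keepdims=True)  # (out, 1)
    # Map each row's largest |w| to 127; all-zero rows get a dummy scale of 1.
    scale = np.where(absmax > 0, absmax / 127.0, 1.0).astype(np.float32)
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize_pcs_int8(q, scale):
    """Approximate reconstruction: w_hat = q * scale (row-wise broadcast)."""
    return q.astype(np.float32) * scale
```

Because each row keeps its own scale, the reconstruction error of any weight is bounded by half of its row's scale. MLX's group-affine scheme instead stores a scale and a bias per group of consecutive weights; the note above reports that this does not hold up for UMT5-XXL.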

The precast advantage

Standard loading from the HF fp32 shards reads ~21 GB of fp32 tensors and casts them at startup, creating a transient memory peak of ~22 GB before any inference runs. These precast files eliminate that transient entirely: the weights are already stored in the target dtype, and the encoder loads them directly via lazy mmap.
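The lazy-mmap idea can be sketched with NumPy: each tensor is mapped straight from its byte range in the safetensors file, so nothing is read or cast until it is used. This is an illustration, not the mlx_umt5 loader. mmap_tensors is a hypothetical helper, and BF16 is left out of the dtype table because NumPy has no native bfloat16 type (MLX does).

```python
import json
import struct

import numpy as np

# safetensors dtype tags -> little-endian NumPy dtypes (BF16 omitted on purpose).
DTYPES = {"F32": np.dtype("<f4"), "F16": np.dtype("<f2"),
          "I8": np.dtype("i1"), "I32": np.dtype("<i4")}


def mmap_tensors(path):
    """Map every tensor in a .safetensors file without reading or casting it."""
    with open(path, "rb") as f:
        (n,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(n))
    header.pop("__metadata__", None)
    tensors = {}
    for name, info in header.items():
        start, _end = info["data_offsets"]  # offsets relative to the data section
        tensors[name] = np.memmap(
            path,
            dtype=DTYPES[info["dtype"]],
            mode="r",                       # read-only view backed by the page cache
            offset=8 + n + start,           # skip length prefix and JSON header
            shape=tuple(info["shape"]),
        )
    return tensors
```

Because the int8 and bf16 files already store the target dtype, a loader along these lines never materializes an fp32 copy of the weights, which is the source of the transient in the shard-loading path.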

Attribution

Converted from google/umt5-xxl (Apache-2.0) and hosted as a convenience. Full credit to Google for the UMT5 model, architecture, and weights. This conversion does not change the license.
