gemma-4-31B-it-FP8

Fine-grained FP8 (block-quantized) version of google/gemma-4-31B-it, produced with transformers' FineGrainedFP8Config. Memory footprint drops from ~62GB (BF16) to ~33GB.

The vision tower and lm_head are kept in BF16 — see "Why these settings" below.
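
The footprint figures above follow from simple bytes-per-weight arithmetic. The sketch below assumes the nominal 31B parameter count and, as a labeled assumption, one float32 scale per 64x64 weight block; it does not try to split the total between quantized and BF16 modules.

```python
# Back-of-the-envelope weight memory, assuming the nominal 31B parameter count.
PARAMS = 31e9

bf16_gb = PARAMS * 2 / 1e9      # 2 bytes per BF16 weight
fp8_floor_gb = PARAMS * 1 / 1e9  # 1 byte per weight if *every* weight were FP8

# Assumption: one float32 (4-byte) scale per 64x64 block of FP8 weights.
scale_overhead_gb = (PARAMS / (64 * 64)) * 4 / 1e9

print(f"BF16 weights:        ~{bf16_gb:.0f} GB")
print(f"All-FP8 lower bound: ~{fp8_floor_gb:.0f} GB")
print(f"Block-scale overhead: ~{scale_overhead_gb:.2f} GB")
```

The reported ~33GB sits a little above the all-FP8 floor, which is consistent with some modules staying in BF16; the exact split is not computed here.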

Recipe

from transformers import FineGrainedFP8Config, AutoModelForImageTextToText, AutoTokenizer, AutoProcessor
import torch

model_id = "google/gemma-4-31B-it"

quant_config = FineGrainedFP8Config(
    # block_size=64, not the default 128: unlike the sibling MoE checkpoint
    # (gemma-4-26B-A4B-it), every LM dim here -- hidden_size=5376,
    # intermediate_size=21504 -- IS divisible by 128, so the default block_size
    # loads and runs without any shape error. But it silently produces NaNs
    # starting in layer 1's self_attn.q_proj that corrupt every later layer and
    # eventually crash generate()'s sampling step with a CUDA device-side
    # assert ("probability tensor contains either inf, nan or element < 0").
    # 64x64 blocks give the FP8 kernel finer-grained per-block scales and the
    # NaNs disappear -- verified clean logits/hidden_states and correct
    # generation with block_size=64. This is a numerical-precision/kernel
    # issue, not a shape mismatch: passing the shape-divisibility check is NOT
    # sufficient evidence that a block size is safe for this model. Always
    # check actual output values (not just that loading/forward "succeeds").
    weight_block_size=(64, 64),
    # model.vision_tower's MLP intermediate_size is 4304 = 16*269 (269 is
    # prime), so it isn't divisible by 64 (or 128) -- excluded explicitly
    # rather than relying on the quantizer's shape-mismatch fallback (which
    # silently leaves a layer in BF16 if its shape doesn't tile, but still
    # wraps it in the FP8 module class). lm_head is excluded too, the
    # conventional choice for numerical stability of the output projection.
    modules_to_not_convert=["model.vision_tower", "lm_head"],
)

model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    quantization_config=quant_config,
    dtype="auto",
    device_map="auto",
)
tok = AutoTokenizer.from_pretrained(model_id)
proc = AutoProcessor.from_pretrained(model_id)

model.save_pretrained("gemma-4-31B-it-FP8")
tok.save_pretrained("gemma-4-31B-it-FP8")
proc.save_pretrained("gemma-4-31B-it-FP8")
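
The divisibility claims in the recipe comments can be checked with plain arithmetic. This is a minimal sketch using only the dimensions quoted above; `tiles_evenly` is a hypothetical helper name, not a transformers function.

```python
# Dimensions quoted in the recipe comments.
DIMS = {
    "hidden_size": 5376,
    "intermediate_size": 21504,
    "vision_tower MLP intermediate_size": 4304,
}


def tiles_evenly(dim: int, block: int) -> bool:
    """Return True if `dim` splits into whole blocks of size `block`."""
    return dim % block == 0


for block in (64, 128):
    for name, dim in DIMS.items():
        status = "ok" if tiles_evenly(dim, block) else f"remainder {dim % block}"
        print(f"block={block:<3} {name}={dim}: {status}")
```

As the comments stress, passing this check only rules out shape errors. It says nothing about whether the resulting outputs are numerically sound.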

Why these settings

  • weight_block_size=(64, 64): the default (128, 128) block divides every language-model dimension here cleanly (hidden_size=5376, intermediate_size=21504), so the model loads and runs without any error. It nevertheless produces NaNs, first seen in layer 1's self_attn.q_proj, that silently propagate through the rest of the network. The problem was caught only by checking logits and hidden_states for NaN/Inf directly; there is no crash at load time, and the failure surfaces only later, during sampling in generate(). Switching to 64x64 blocks resolves it. Don't assume a block size is numerically safe just because the model loads and a forward pass doesn't error.

  • modules_to_not_convert=["model.vision_tower", "lm_head"]: the vision tower's MLP intermediate_size (4304) isn't divisible by 64 or 128. The tower is also small next to the 31B-parameter backbone, so quantizing it would save little memory while risking degraded image understanding. lm_head is excluded for the usual numerical-stability reasons.

    Note the module path must be the full dotted path as it appears in model.named_modules() ("model.vision_tower", not just "vision_tower"): should_convert_module only matches a pattern that is a prefix or a suffix of a module's full path, and a bare "vision_tower" is neither for the tower's submodules, whose paths begin with "model.vision_tower.".
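
The advice above, checking output values rather than trusting a clean load, can be sketched with PyTorch forward hooks. This is an illustrative sketch, not the procedure actually used for this checkpoint: `first_nonfinite_module` and the toy `DivideByZero` layer are hypothetical names, and the demo runs on a tiny CPU model rather than the quantized one.

```python
import torch
from torch import nn


def first_nonfinite_module(model: nn.Module, *args, **kwargs):
    """Run one forward pass and return the name of the first submodule (in
    completion order) whose output contains NaN or Inf, or None if none does."""
    bad, handles = [], []

    def make_hook(name):
        def hook(module, inputs, output):
            # Outputs may be a tensor or a tuple/list; inspect floating tensors only.
            outs = output if isinstance(output, (tuple, list)) else (output,)
            for t in outs:
                if torch.is_tensor(t) and t.is_floating_point():
                    if not torch.isfinite(t).all():
                        bad.append(name)
                        break
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module itself
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(*args, **kwargs)
    finally:
        for handle in handles:
            handle.remove()  # always detach the hooks again
    return bad[0] if bad else None


class DivideByZero(nn.Module):
    """Toy layer that deliberately emits Inf/NaN, standing in for a bad kernel."""

    def forward(self, x):
        return x / torch.zeros_like(x)


toy = nn.Sequential(nn.Linear(4, 4), DivideByZero(), nn.Linear(4, 4))
print(first_nonfinite_module(toy, torch.ones(2, 4)))  # the submodule named "1"
```

On the real model the analogous call would pass the tokenized inputs as keyword arguments. Because hooks fire as each submodule finishes, the first name reported is the earliest point at which non-finite values appear.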

Usage

from transformers import AutoModelForImageTextToText, AutoTokenizer
import torch

model = AutoModelForImageTextToText.from_pretrained(
    "hugg1ngfac3/gemma-4-31B-it-FP8",
    dtype="auto",
    device_map="auto",
)
tok = AutoTokenizer.from_pretrained("hugg1ngfac3/gemma-4-31B-it-FP8")

msgs = [{"role": "user", "content": "Say hello in one short sentence."}]
inputs = tok.apply_chat_template(msgs, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
out = model.generate(**inputs, max_new_tokens=30)
print(tok.decode(out[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))

Requires kernels>=0.12.0,<0.13 for the FP8 kernel (pip install -U "kernels>=0.12.0,<0.13").
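
To confirm the installed version falls in the range stated above before loading the model, a standard-library check along these lines may help. `kernels_version_ok` is a hypothetical helper, and the parsing is deliberately simplified: it compares only the major and minor components, so pre-release builds of 0.12.0 are treated as in range.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def kernels_version_ok(installed: str) -> bool:
    """True if `installed` looks like a 0.12.x release (>=0.12.0,<0.13)."""
    match = re.match(r"(\d+)\.(\d+)", installed)
    if match is None:
        return False
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) == (0, 12)


try:
    installed = version("kernels")
    verdict = "ok" if kernels_version_ok(installed) else "outside >=0.12.0,<0.13"
    print(f"kernels {installed}: {verdict}")
except PackageNotFoundError:
    print("kernels is not installed")
```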

Safetensors model size: 31B params. Tensor types: F32, BF16, F8_E4M3.