Flax checkpoint with float16 dtype outputs nan

#21
by gigant - opened

Hello,

The Flax version of this T5 checkpoint outputs exclusively NaN when run with the float16 dtype. In float32 it works fine.

To reproduce the error:

from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
model = FlaxT5ForConditionalGeneration.from_pretrained("google/flan-t5-base", dtype="float16")

# Import shift_tokens_right from the model's module to build decoder_input_ids
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")

pad_token_id = model.config.pad_token_id
decoder_start_token_id = model.config.decoder_start_token_id


input_text = "translate English to German: How old are you?"

input_ids = tokenizer(input_text, return_tensors="np")

decoder_input_ids = shift_tokens_right_fn(
    input_ids["input_ids"], pad_token_id, decoder_start_token_id
)

outputs = model.generate(**input_ids)
print(tokenizer.decode(outputs[0][0], skip_special_tokens=True))
print(model(**input_ids, decoder_input_ids=decoder_input_ids))

The generation result is just empty and the model output is filled with NaNs.

Is there a way to prevent this, or a correct way to convert the float32 weights to float16 that will not give this error?

Thank you

@sanchit-gandhi, would you know what might be happening?

Hey @gigant - I ran the code snippet with a few settings and believe this phenomenon is an artefact of the float16 dynamic range, rather than an issue with the model. E.g. if we run the model with a dtype of jnp.float32 or jnp.bfloat16 (larger dynamic range), we get sensible outputs. If we switch to a dtype of jnp.float16 (smaller dynamic range), we get NaNs as you have described => this suggests that the model's activations exceed the dynamic range of float16, so it needs the larger range to operate. This is likely a consequence of the model being trained directly in bfloat16.
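
For intuition, a small sketch of mine (not part of the original reply) comparing the representable ranges of the two half-precision dtypes:

import jax.numpy as jnp

# float16 has a 5-bit exponent and saturates around 6.5e4; bfloat16 keeps
# float32's 8-bit exponent, so its representable range is vastly larger.
print(jnp.finfo(jnp.float16).max)   # 65504.0
print(jnp.finfo(jnp.bfloat16).max)  # ~3.39e38

# A value that is fine in bfloat16 overflows to inf in float16, and inf then
# propagates through the network as NaN.
print(jnp.asarray(1e5, dtype=jnp.bfloat16))  # finite
print(jnp.asarray(1e5, dtype=jnp.float16))   # inf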

Note that the dtype argument only changes the dtype of the model computation (forward pass), not the parameters. This gets you memory and latency improvements, since the model operations are run in a lower dtype, but the parameters themselves are unchanged.
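
As a quick check (my own sketch, not part of the original reply), inspecting the parameter dtypes shows this: even when dtype=jnp.bfloat16 is passed, every parameter leaf is still float32.

import jax
import jax.numpy as jnp
from transformers import FlaxT5ForConditionalGeneration

model = FlaxT5ForConditionalGeneration.from_pretrained("google/flan-t5-base", dtype=jnp.bfloat16)

# The forward pass runs in bfloat16, but the loaded parameters stay float32.
param_dtypes = jax.tree_util.tree_map(lambda p: p.dtype, model.params)
print(set(jax.tree_util.tree_leaves(param_dtypes)))  # {dtype('float32')}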

If you want to change the dtype of the parameters, you can set:

model = FlaxT5ForConditionalGeneration.from_pretrained("google/flan-t5-base")
model.params = model.to_fp16(model.params)

=> this won't speed up the forward pass (since you still run it in a dtype of float32), but will give you a memory saving by downcasting the params and thus reducing their memory footprint. You can also downcast them to bfloat16 in much the same way:

model = FlaxT5ForConditionalGeneration.from_pretrained("google/flan-t5-base")
model.params = model.to_bf16(model.params)
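
To make the memory saving concrete, here is a rough sketch of mine (the GB figures are approximate) that reloads the model and compares the parameter footprint before and after the downcast:

import jax
from transformers import FlaxT5ForConditionalGeneration

model = FlaxT5ForConditionalGeneration.from_pretrained("google/flan-t5-base")
fp32_bytes = sum(p.nbytes for p in jax.tree_util.tree_leaves(model.params))

model.params = model.to_bf16(model.params)
bf16_bytes = sum(p.nbytes for p in jax.tree_util.tree_leaves(model.params))

# flan-t5-base has roughly 250M parameters: ~1 GB in float32 vs ~0.5 GB in bfloat16.
print(f"float32: {fp32_bytes / 1e9:.2f} GB, bfloat16: {bf16_bytes / 1e9:.2f} GB")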

TLDR: for fast inference and memory savings, set dtype=jnp.bfloat16. For just memory savings, downcast the params. For max performance, do both!

import jax.numpy as jnp

model = FlaxT5ForConditionalGeneration.from_pretrained("google/flan-t5-base", dtype=jnp.bfloat16)
model.params = model.to_bf16(model.params)
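
Putting it together on the original prompt, a small end-to-end sketch of mine (assuming the model's default generation settings, where outputs.sequences holds the generated token ids); in bfloat16 this should give a sensible translation rather than NaNs:

import jax.numpy as jnp
from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
model = FlaxT5ForConditionalGeneration.from_pretrained("google/flan-t5-base", dtype=jnp.bfloat16)
model.params = model.to_bf16(model.params)

# bfloat16 forward pass + bfloat16 params: fast and memory-efficient.
inputs = tokenizer("translate English to German: How old are you?", return_tensors="np")
outputs = model.generate(**inputs)
print(tokenizer.decode(outputs.sequences[0], skip_special_tokens=True))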

Thank you very much for the comprehensive answer! It's much clearer now.

gigant changed discussion status to closed
