How to load this .npz model

#3
by bsmani - opened

Hi team, please tell me how to load this .npz model and run inference.

Google org
edited Aug 5

Hi @bsmani, to load the .npz model and run inference with the paligemma-3b-ft-ai2d-224-jax model, you can follow the steps I tried in this Google Colab link.

  1. Load the model:
  • Load the weights from the .npz file and use the jax.tree_util.tree_map function to convert them into a JAX parameter tree.
  • Modify the state dictionary so that it matches the shape of the input data.
  2. Perform inference:
  • Call the predict_step function with the processed input and the loaded model.
  • Extract the logits and apply softmax to obtain probability distributions.
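A minimal sketch of the loading step, assuming the .npz file stores a flat mapping whose keys use "/" to separate nesting levels (an assumption; inspect `np.load(path).files` to confirm the key layout of your checkpoint). The function name `load_npz_params` and the toy keys are hypothetical. The real PaliGemma parameter names and model code come from the Colab, not from this snippet.

```python
import io

import jax
import jax.numpy as jnp
import numpy as np


def load_npz_params(path_or_file):
    """Load a flat .npz checkpoint and rebuild it as a nested parameter dict.

    Assumption: keys use "/" between nesting levels, e.g. "img/head/kernel".
    """
    nested = {}
    with np.load(path_or_file, allow_pickle=False) as flat:
        for key in flat.files:
            *parents, leaf = key.split("/")
            node = nested
            for name in parents:
                node = node.setdefault(name, {})
            node[leaf] = flat[key]  # reads this array into memory
    # Convert every NumPy leaf into a JAX array (placed on the default device).
    return jax.tree_util.tree_map(jnp.asarray, nested)


# Usage with a tiny in-memory .npz (hypothetical keys, not real PaliGemma ones).
buf = io.BytesIO()
np.savez(
    buf,
    **{
        "img/head/kernel": np.ones((2, 3), np.float32),
        "llm/embedder/input_embedding": np.zeros((4, 2), np.float32),
    },
)
buf.seek(0)
params = load_npz_params(buf)
print(jax.tree_util.tree_map(lambda x: x.shape, params))
```

For a real checkpoint you would pass the file path instead of the in-memory buffer. Printing the shapes first is a quick way to compare the loaded tree against what the model code expects.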
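For the last step, here is a minimal sketch of turning logits into probabilities with `jax.nn.softmax`. It assumes the logits returned by predict_step have the vocabulary on the last axis. The logits are made-up values, not model output. The greedy `argmax` pick is an illustrative addition and not part of the steps listed above.

```python
import jax
import jax.numpy as jnp


def logits_to_probs(logits):
    """Convert logits of shape [..., vocab_size] into probability distributions."""
    # softmax over the last (vocabulary) axis, so each distribution sums to 1
    return jax.nn.softmax(logits, axis=-1)


# Made-up logits: batch of 1, 2 positions, toy vocabulary of 3 tokens.
logits = jnp.array([[[2.0, 1.0, 0.1], [0.0, 0.0, 0.0]]])
probs = logits_to_probs(logits)

# Illustrative greedy pick: most likely token id at each position.
top_tokens = jnp.argmax(probs, axis=-1)
print(probs)
print(top_tokens)
```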

Kindly try these steps and let me know if you face any issues. Thank you.
