license: apache-2.0
tags:
- jax
- safetensors
Parametric PerceptNet Bio-Fitted
Model Description
How to use it
Install the model's package from source:
git clone https://github.com/Jorgvt/paramperceptnet.git
cd paramperceptnet
pip install -e .
1. Import the necessary libraries:
import json
from huggingface_hub import hf_hub_download
import flax
import jax
from jax import numpy as jnp
import orbax.checkpoint
from ml_collections import ConfigDict
from paramperceptnet.models import PerceptNet
2. Download the configuration:
config_path = hf_hub_download(repo_id="Jorgvt/ppnet-bio-fitted",
                              filename="config.json")
with open(config_path, "r") as f:
    config = ConfigDict(json.load(f))
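Before wrapping the JSON in a `ConfigDict`, it can help to see which settings the file actually contains. The sketch below flattens a nested dictionary into dotted paths using only the standard library. The keys in `example` are hypothetical and chosen purely for illustration; the real `config.json` defines its own.

```python
import json


def flatten_keys(cfg, prefix=""):
    """Yield (dotted_path, value) pairs for every leaf of a nested dict."""
    for key, value in cfg.items():
        path = f"{prefix}{key}"
        if isinstance(value, dict):
            # Recurse into nested sections, extending the dotted prefix.
            yield from flatten_keys(value, prefix=f"{path}.")
        else:
            yield path, value


# Hypothetical config contents, for illustration only.
example = json.loads('{"BATCH_SIZE": 64, "OPTIM": {"LR": 0.001}}')
flat = dict(flatten_keys(example))
```

With the downloaded file, the same function could be applied to the result of `json.load(f)` to list every available option.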
3. Download the weights:
weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-bio-fitted",
                               filename="weights.msgpack")
with open(weights_path, "rb") as f:
    variables = orbax.checkpoint.msgpack_utils.msgpack_restore(f.read())
# Split the restored variables into trainable parameters and the remaining state.
state = variables["state"]
params = variables["params"]
# Convert every leaf to a JAX array (requires `jax` and `jax.numpy as jnp`).
state = jax.tree_util.tree_map(jnp.array, state)
params = jax.tree_util.tree_map(jnp.array, params)
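As a quick sanity check after restoring, one might count the scalar entries in the parameter tree. This is a minimal sketch, assuming the restored object is an ordinary pytree of arrays; `toy_params` is a made-up stand-in with hypothetical names and shapes, not the real checkpoint contents.

```python
import jax
import numpy as np


def count_parameters(tree):
    """Total number of scalar entries across all array leaves of a pytree."""
    return sum(int(np.size(leaf)) for leaf in jax.tree_util.tree_leaves(tree))


# Hypothetical stand-in for the restored `params` (names and shapes are made up).
toy_params = {
    "Conv_0": {"kernel": np.zeros((3, 3, 3, 8)), "bias": np.zeros((8,))},
}
n_params = count_parameters(toy_params)
```

Calling `count_parameters(params)` on the restored weights should work the same way, since `tree_leaves` accepts both NumPy and JAX arrays as leaves.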
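4. Use the model:
from jax import numpy as jnp
# Assumption: PerceptNet is constructed from the downloaded config; check the
# paramperceptnet repository if the constructor signature differs.
model = PerceptNet(config)
pred = model.apply({"params": params, **state}, jnp.ones((1, 384, 512, 3)))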
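The dummy input above has shape `(1, 384, 512, 3)`, which suggests a batch-first, channels-last layout. A real image could be brought into that layout roughly as follows. The helper `to_batch` is hypothetical, and the scaling to `[0, 1]` is an assumption this card does not confirm, so check the preprocessing used in the paramperceptnet repository before relying on it.

```python
import numpy as np
from jax import numpy as jnp


def to_batch(image_uint8):
    """Convert an (H, W, 3) uint8 image to a (1, H, W, 3) float32 batch."""
    # Assumed scaling to [0, 1]; the model card does not state the expected range.
    x = jnp.asarray(image_uint8, dtype=jnp.float32) / 255.0
    return x[None, ...]  # add the leading batch axis


# Synthetic all-white image standing in for a real photo.
image = np.full((384, 512, 3), 255, dtype=np.uint8)
batch = to_batch(image)
```

The resulting `batch` can then be passed to `model.apply` in place of `jnp.ones(...)`.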