metadata
license: apache-2.0
tags:
  - jax
  - safetensors

Parametric PerceptNet Bio-Fitted

Model Description

How to use it

Install the model's package from source:

git clone https://github.com/Jorgvt/paramperceptnet.git
cd paramperceptnet
pip install -e .

1. Import the necessary libraries:

import json

import flax
import jax
import jax.numpy as jnp
import orbax.checkpoint
from huggingface_hub import hf_hub_download
from ml_collections import ConfigDict

from paramperceptnet.models import PerceptNet

2. Download the configuration:

config_path = hf_hub_download(repo_id="Jorgvt/ppnet-bio-fitted",
                              filename="config.json")
with open(config_path, "r") as f:
    config = ConfigDict(json.load(f))

3. Download the weights

3.1. Using safetensors

from safetensors.flax import load_file

weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-bio-fitted",
                               filename="weights.safetensors")
variables = load_file(weights_path)
variables = flax.traverse_util.unflatten_dict(variables, sep=".")
state = variables["state"]
params = variables["params"]
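The `unflatten_dict` call above implies that the safetensors file stores a flat mapping whose keys are dot-joined paths, which is then rebuilt into the nested dictionary that holds `state` and `params`. The pure-Python function below is only an illustrative stand-in for that step: the name `unflatten` and the example keys are hypothetical, and this is not flax's implementation.

```python
def unflatten(flat, sep="."):
    """Rebuild a nested dict from a flat dict whose keys are sep-joined paths."""
    nested = {}
    for key, value in flat.items():
        *parents, leaf = key.split(sep)
        node = nested
        # Walk (and create on demand) one nested dict per path component.
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
    return nested


# Hypothetical keys; in the real checkpoint the leaves are arrays.
flat = {
    "params.Dense_0.kernel": 1.0,
    "params.Dense_0.bias": 0.0,
    "state.batch_stats.mean": 0.5,
}
nested = unflatten(flat)
print(nested["params"])
```

After this step, the top-level keys (`"params"` and `"state"` in the example) can be indexed directly, as done with `variables` above.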

3.2. Using msgpack

weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-bio-fitted",
                               filename="weights.msgpack")
with open(weights_path, "rb") as f:
    variables = orbax.checkpoint.msgpack_utils.msgpack_restore(f.read())
variables = jax.tree_util.tree_map(lambda x: jnp.array(x), variables)
state = variables["state"]
params = variables["params"]
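The two options in step 3 can be folded into one helper that picks the loader from the file extension. This is a sketch, not part of `paramperceptnet`: the name `load_variables` is hypothetical, and it only reuses the calls already shown in 3.1 and 3.2, imported lazily so that only the libraries for the chosen format need to be installed.

```python
from pathlib import Path


def load_variables(weights_path):
    """Return the nested variables dict from a .safetensors or .msgpack file.

    Hypothetical helper: it wraps the same calls used in steps 3.1 and 3.2.
    """
    suffix = Path(weights_path).suffix
    if suffix == ".safetensors":
        from flax import traverse_util
        from safetensors.flax import load_file

        # Flat dict with dot-joined keys -> nested dict.
        return traverse_util.unflatten_dict(load_file(weights_path), sep=".")
    if suffix == ".msgpack":
        import jax
        import jax.numpy as jnp
        import orbax.checkpoint

        with open(weights_path, "rb") as f:
            restored = orbax.checkpoint.msgpack_utils.msgpack_restore(f.read())
        # Convert every leaf of the restored tree to a JAX array.
        return jax.tree_util.tree_map(jnp.asarray, restored)
    raise ValueError(f"Unsupported weights format: {suffix!r}")
```

With it, `variables = load_variables(weights_path)` works for either file, and `variables["state"]` and `variables["params"]` are read as before.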

4. Use the model

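from jax import numpy as jnp

# Assumption: PerceptNet takes the loaded config as its only constructor argument.
model = PerceptNet(config)
# Dummy input: a batch with one 384x512 image and 3 channels.
pred = model.apply({"params": params, **state}, jnp.ones((1, 384, 512, 3)))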
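The first argument to `model.apply` is a single dictionary of variable collections: `params` goes under the `"params"` key, and `**state` spreads whatever other collections the checkpoint carries next to it. The sketch below shows only that dictionary merge, using plain Python stand-ins. The collection name `"batch_stats"` and the layer names are assumptions for illustration, not taken from this checkpoint.

```python
# Hypothetical stand-ins for the loaded trees; the real leaves are arrays.
params = {"Conv_0": {"kernel": "<array>"}}
state = {"batch_stats": {"Norm_0": {"mean": "<array>"}}}  # assumed collection name

# Same construction as in the call to model.apply above.
variables = {"params": params, **state}

# Each top-level key is one collection that the model can read from.
print(sorted(variables))
```

Printing `state.keys()` on the real checkpoint is a quick way to see which collections it actually contains.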