---
license: apache-2.0
tags:
- jax
- safetensors
---

# Baseline PerceptNet

## Model Description

## How to use it

### Install the model's package from source:

```
git clone https://github.com/Jorgvt/paramperceptnet.git
cd paramperceptnet
pip install -e .
```

### 1. Import the necessary libraries:

```
import json

from huggingface_hub import hf_hub_download
import flax
import jax
from jax import numpy as jnp
import orbax.checkpoint
from ml_collections import ConfigDict

from paramperceptnet.models import Baseline as PerceptNet
```

### 2. Download the configuration

```
config_path = hf_hub_download(repo_id="Jorgvt/ppnet-baseline",
                              filename="config.json")
with open(config_path, "r") as f:
    config = ConfigDict(json.load(f))
```

### 3. Download the weights

#### 3.1. Using `safetensors`

```
from safetensors.flax import load_file

weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-baseline",
                               filename="weights.safetensors")
# safetensors stores a flat dict with "."-joined keys; rebuild the nested tree.
variables = load_file(weights_path)
variables = flax.traverse_util.unflatten_dict(variables, sep=".")
params = variables["params"]
```

#### 3.2. Using `msgpack`

```
weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-fully-trained",
                               filename="weights.msgpack")
with open(weights_path, "rb") as f:
    variables = orbax.checkpoint.msgpack_utils.msgpack_restore(f.read())
# Convert the restored arrays to JAX arrays.
variables = jax.tree_util.tree_map(lambda x: jnp.array(x), variables)
params = variables["params"]
```

### 4. Use the model

```
from jax import numpy as jnp

# Assumption: the model is built from the downloaded config; check the
# paramperceptnet package for the exact constructor signature.
model = PerceptNet(config)
pred = model.apply({"params": params}, jnp.ones((1,384,512,3)))
```
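
Step 3.1 relies on `flax.traverse_util.unflatten_dict(..., sep=".")` to turn the flat, dot-separated keys stored by `safetensors` back into the nested dictionary that `model.apply` expects. The following is a minimal standard-library sketch of that transformation, meant only to illustrate the idea. The `unflatten` helper and the key names (`Conv_0`, `Dense_0`) are hypothetical and are not taken from the actual checkpoint.

```python
def unflatten(flat, sep="."):
    """Rebuild a nested dict from a flat dict whose keys are sep-joined paths."""
    nested = {}
    for key, value in flat.items():
        *parents, leaf = key.split(sep)
        node = nested
        for name in parents:
            # Walk down the tree, creating intermediate dicts as needed.
            node = node.setdefault(name, {})
        node[leaf] = value
    return nested


# Hypothetical key names, for illustration only; the real checkpoint's keys differ.
flat = {
    "params.Conv_0.kernel": [[1.0]],
    "params.Conv_0.bias": [0.0],
    "params.Dense_0.kernel": [[2.0]],
}
nested = unflatten(flat)
print(sorted(nested["params"]))
```

After unflattening, `nested["params"]` plays the same role as `variables["params"]` in step 3.1: a tree keyed first by layer name and then by parameter name.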