---
license: apache-2.0
tags:
- jax
- safetensors
---

# Baseline PerceptNet

## Model Description

## How to use it

### Install the model's package from source

```bash
git clone https://github.com/Jorgvt/paramperceptnet.git
cd paramperceptnet
pip install -e .
```

### 1. Import the necessary libraries

```python
import json

from huggingface_hub import hf_hub_download
import flax
import orbax.checkpoint
from ml_collections import ConfigDict

from paramperceptnet.models import Baseline as PerceptNet
```

### 2. Download the configuration

```python
config_path = hf_hub_download(repo_id="Jorgvt/ppnet-baseline",
                              filename="config.json")
with open(config_path, "r") as f:
    config = ConfigDict(json.load(f))
```

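`ConfigDict` wraps the parsed JSON so that its entries, including nested ones, can be read as attributes (`config.some_key`) as well as by key. As a rough, dependency-free illustration of that access pattern, the sketch below uses the standard library's `types.SimpleNamespace` instead of `ml_collections`; it is not how `ConfigDict` is implemented, and the JSON keys are hypothetical rather than the actual contents of `config.json`.

```python
import json
from types import SimpleNamespace

# Hypothetical config contents; the real config.json has its own keys.
raw = '{"model": {"features": 3, "kernel_size": 5}, "seed": 42}'

# object_hook is called on every decoded JSON object (innermost first),
# so nested objects also become namespaces with attribute access.
config = json.loads(raw, object_hook=lambda d: SimpleNamespace(**d))

print(config.seed)               # 42
print(config.model.kernel_size)  # 5
```
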
### 3. Download the weights

#### 3.1. Using `safetensors`

```python
from safetensors.flax import load_file

weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-baseline",
                               filename="weights.safetensors")
# load_file returns a flat dict of arrays whose keys are "."-joined paths.
variables = load_file(weights_path)
# Rebuild the nested variable tree that Flax expects.
variables = flax.traverse_util.unflatten_dict(variables, sep=".")
params = variables["params"]
```

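The `safetensors` file stores the variables as a flat mapping, and `flax.traverse_util.unflatten_dict(..., sep=".")` turns it back into nested dictionaries by splitting each key on `.`. The plain-Python sketch below mimics that step to show the resulting structure; `unflatten_dotted` is a hypothetical helper written only for illustration, and the layer names (`Conv_0`, `Conv_1`) and string placeholders stand in for the real keys and arrays.

```python
def unflatten_dotted(flat: dict, sep: str = ".") -> dict:
    """Rebuild a nested dict from flat keys such as 'params.Conv_0.kernel'."""
    nested: dict = {}
    for flat_key, value in flat.items():
        *parents, leaf = flat_key.split(sep)
        node = nested
        for part in parents:
            # Create each intermediate level the first time it is seen.
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


# Hypothetical flat keys; strings stand in for the stored arrays.
flat = {
    "params.Conv_0.kernel": "k0",
    "params.Conv_0.bias": "b0",
    "params.Conv_1.kernel": "k1",
}
variables = unflatten_dotted(flat)
params = variables["params"]
print(sorted(params))  # ['Conv_0', 'Conv_1']
```
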
#### 3.2. Using `msgpack`

```python
import jax
import jax.numpy as jnp

weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-baseline",
                               filename="weights.msgpack")
with open(weights_path, "rb") as f:
    variables = orbax.checkpoint.msgpack_utils.msgpack_restore(f.read())
# Convert every restored leaf into a JAX array.
variables = jax.tree_util.tree_map(lambda x: jnp.array(x), variables)
params = variables["params"]
```

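`jax.tree_util.tree_map` applies a function to every leaf of a nested structure while keeping the structure itself unchanged; above, it turns each restored array into a JAX array. The sketch below reproduces that idea for nested dictionaries only (the real `tree_map` also handles lists, tuples and other registered containers). `map_leaves` is a hypothetical helper, and converting lists of integers to floats merely stands in for the `jnp.array` conversion.

```python
from typing import Any, Callable


def map_leaves(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Apply fn to every non-dict leaf of a nested dict, preserving its keys."""
    if isinstance(tree, dict):
        return {key: map_leaves(fn, value) for key, value in tree.items()}
    return fn(tree)


# Hypothetical restored variables; the real leaves are arrays, not lists.
restored = {"params": {"Conv_0": {"kernel": [1, 2, 3], "bias": [0]}}}
converted = map_leaves(lambda leaf: [float(v) for v in leaf], restored)
print(converted["params"]["Conv_0"]["kernel"])  # [1.0, 2.0, 3.0]
```
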
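### 4. Use the model

```python
from jax import numpy as jnp

# Build the network from the downloaded configuration
# (assumes the module takes the ConfigDict as its only argument).
model = PerceptNet(config)

# Dummy input: a batch of one 384x512 RGB image in NHWC layout.
pred = model.apply({"params": params}, jnp.ones((1, 384, 512, 3)))
```
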
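The call above feeds a dummy batch of shape `(1, 384, 512, 3)`: one 384x512 RGB image with a leading batch axis and channels last. A real image has to be brought into the same layout first. The NumPy sketch below does that for a single `uint8` image; `to_batch` is a hypothetical helper, and scaling pixel values to `[0, 1]` is an assumption, since the expected input range is not stated here. The result can be wrapped with `jnp.asarray` before calling `model.apply`.

```python
import numpy as np


def to_batch(image: np.ndarray) -> np.ndarray:
    """Turn one (H, W, 3) uint8 image into a (1, H, W, 3) float32 batch."""
    if image.ndim != 3 or image.shape[-1] != 3:
        raise ValueError(f"expected an (H, W, 3) image, got shape {image.shape}")
    # Assumption: the model expects float inputs scaled to [0, 1].
    scaled = image.astype(np.float32) / 255.0
    # Add the leading batch axis (NHWC, like jnp.ones((1, 384, 512, 3))).
    return scaled[np.newaxis, ...]


# A synthetic white image standing in for a real one.
image = np.full((384, 512, 3), 255, dtype=np.uint8)
batch = to_batch(image)
print(batch.shape, batch.dtype)  # (1, 384, 512, 3) float32
```
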