---
license: apache-2.0
tags:
- jax
- safetensors
---
# Baseline PerceptNet
## Model Description
## How to use it
### Install the model's package from source:
```
git clone https://github.com/Jorgvt/paramperceptnet.git
cd paramperceptnet
pip install -e .
```
### 1. Import necessary libraries:
```
import json
from huggingface_hub import hf_hub_download
import jax
from jax import numpy as jnp
import flax
import orbax.checkpoint
from ml_collections import ConfigDict
from paramperceptnet.models import Baseline as PerceptNet
```
### 2. Download the configuration
```
config_path = hf_hub_download(repo_id="Jorgvt/ppnet-baseline",
                              filename="config.json")
with open(config_path, "r") as f:
    config = ConfigDict(json.load(f))
```
### 3. Download the weights
#### 3.1. Using `safetensors`
```
from safetensors.flax import load_file
weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-baseline",
                               filename="weights.safetensors")
variables = load_file(weights_path)
variables = flax.traverse_util.unflatten_dict(variables, sep=".")
params = variables["params"]
```
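The `unflatten_dict(..., sep=".")` call is needed because safetensors stores a flat mapping from string keys to tensors, while Flax expects a nested dictionary. The pure-Python sketch below illustrates that transformation. Note that `unflatten` is a hypothetical helper written for illustration (it is not the Flax implementation), and the key names are made up rather than taken from this model.

```python
def unflatten(flat, sep="."):
    """Rebuild a nested dict from keys such as 'params.Conv_0.kernel'."""
    nested = {}
    for key, value in flat.items():
        *parents, leaf = key.split(sep)
        node = nested
        # Walk (and create when missing) one dict level per parent name.
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
    return nested


# Made-up keys; short strings stand in for the weight arrays.
flat = {
    "params.Conv_0.kernel": "K",
    "params.Conv_0.bias": "b",
}
nested = unflatten(flat)
print(nested["params"]["Conv_0"]["kernel"])
```

After this step the top-level `"params"` entry can be indexed directly, which is what the last line of the snippet above relies on.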
#### 3.2. Using `msgpack`
```
weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-baseline",
                               filename="weights.msgpack")
with open(weights_path, "rb") as f:
    variables = orbax.checkpoint.msgpack_utils.msgpack_restore(f.read())
variables = jax.tree_util.tree_map(lambda x: jnp.array(x), variables)
params = variables["params"]
```
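Either loading route should leave `params` as a nested dictionary of arrays, so it can be useful to list the parameter names and shapes before calling the model. This is a minimal sketch, assuming only that each leaf exposes a `.shape` attribute (as NumPy and JAX arrays do). `leaf_shapes` is a hypothetical helper, and the stand-in tree uses invented layer names and shapes.

```python
from collections.abc import Mapping
from types import SimpleNamespace


def leaf_shapes(tree, prefix=""):
    """Map dotted parameter paths to the shape of each leaf array."""
    shapes = {}
    for name, value in tree.items():
        path = f"{prefix}.{name}" if prefix else name
        if isinstance(value, Mapping):
            # Recurse into sub-modules.
            shapes.update(leaf_shapes(value, path))
        else:
            shapes[path] = tuple(value.shape)
    return shapes


# Stand-in for `params`; layer names and shapes are invented.
demo_params = {
    "Conv_0": {
        "kernel": SimpleNamespace(shape=(3, 3, 3, 8)),
        "bias": SimpleNamespace(shape=(8,)),
    }
}
for path, shape in leaf_shapes(demo_params).items():
    print(path, shape)
```

With the real weights loaded, `leaf_shapes(params)` would be called in place of the stand-in tree.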
### 4. Use the model
```
from jax import numpy as jnp
# Assumption: the model class takes the config as its only argument.
model = PerceptNet(config)
pred = model.apply({"params": params}, jnp.ones((1, 384, 512, 3)))
```