---

license: apache-2.0
tags:
- jax
- safetensors
---


# Baseline PerceptNet

## Model Description

## How to use it

### Install the model's package from source

```bash
git clone https://github.com/Jorgvt/paramperceptnet.git
cd paramperceptnet
pip install -e .
```

### 1. Import the necessary libraries

```python
import json

import jax
from jax import numpy as jnp
import flax
import orbax.checkpoint
from huggingface_hub import hf_hub_download
from ml_collections import ConfigDict

from paramperceptnet.models import Baseline as PerceptNet
```
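If one of the imports in step 1 fails, a quick stdlib check with `importlib.util.find_spec` shows which packages are missing. This is a minimal sketch. The module list is an assumption based on the imports used in this card, and it uses top-level names only, because `find_spec` on a dotted name imports the parent package first.

```python
import importlib.util
from typing import Iterable, List

# Top-level modules used in this card (an assumption: adjust to your setup).
REQUIRED_MODULES = (
    "jax",
    "flax",
    "orbax",
    "huggingface_hub",
    "ml_collections",
    "safetensors",
    "paramperceptnet",
)


def missing_modules(names: Iterable[str] = REQUIRED_MODULES) -> List[str]:
    """Return the module names that cannot be found on the current path."""
    # find_spec returns None for a top-level module that is not installed.
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    print("missing:", missing if missing else "none")
```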

### 2. Download the configuration

```python
config_path = hf_hub_download(repo_id="Jorgvt/ppnet-baseline",
                              filename="config.json")
with open(config_path, "r") as f:
    config = ConfigDict(json.load(f))
```
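`ConfigDict` converts the parsed JSON, including nested objects, into fields you can read as attributes. To inspect `config.json` without `ml_collections`, you can get a rough stdlib stand-in from `json` with an `object_hook`. This sketch is only an approximation and lacks `ConfigDict`'s extra features. The keys below are hypothetical.

```python
import json
from types import SimpleNamespace


def load_config_text(text: str) -> SimpleNamespace:
    """Parse JSON so that nested objects support attribute access."""
    # object_hook runs on every decoded JSON object, innermost first,
    # so nested objects become SimpleNamespace instances too.
    return json.loads(text, object_hook=lambda d: SimpleNamespace(**d))


# Hypothetical keys; the real config.json defines its own fields.
cfg = load_config_text('{"name": "baseline", "gdn": {"kernel_size": 5}}')
print(cfg.name, cfg.gdn.kernel_size)  # → baseline 5
```

To apply it to the downloaded file, use `with open(config_path) as f: cfg = load_config_text(f.read())`.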

### 3. Download the weights

#### 3.1. Using `safetensors`

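```python
from safetensors.flax import load_file

weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-baseline",
                               filename="weights.safetensors")
# `load_file` returns a flat dict keyed by dotted paths ("params...."),
# so unflatten it back into the nested tree that Flax expects.
variables = load_file(weights_path)
variables = flax.traverse_util.unflatten_dict(variables, sep=".")
params = variables["params"]
```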
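The `unflatten_dict(..., sep=".")` call rebuilds the nested parameter tree from the flat, dot-separated keys that `safetensors` stores. Below is a minimal stdlib sketch of the same idea. The layer names are hypothetical, and the checkpoint's real keys may differ.

```python
from typing import Any, Dict


def unflatten_dotted(flat: Dict[str, Any], sep: str = ".") -> Dict[str, Any]:
    """Rebuild a nested dict from keys like "params.Conv_0.kernel"."""
    nested: Dict[str, Any] = {}
    for key, value in flat.items():
        *parents, leaf = key.split(sep)
        node = nested
        for part in parents:
            # Create each intermediate level the first time it is seen.
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


# Hypothetical flat keys with placeholder values instead of arrays.
flat = {
    "params.Conv_0.kernel": "kernel-array",
    "params.Conv_0.bias": "bias-array",
    "params.GDN_0.gamma": "gamma-array",
}
tree = unflatten_dotted(flat)
print(tree["params"]["Conv_0"]["bias"])  # → bias-array
```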

#### 3.2. Using `msgpack`

```python
weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-fully-trained",
                               filename="weights.msgpack")
with open(weights_path, "rb") as f:
    variables = orbax.checkpoint.msgpack_utils.msgpack_restore(f.read())
# Convert every restored leaf into a JAX array.
variables = jax.tree_util.tree_map(jnp.asarray, variables)
params = variables["params"]
```
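The weights can be loaded in two ways, so it can help to confirm that both give a `params` tree with the same structure. Below is a minimal stdlib sketch that lists every leaf as a dotted path. The example trees and layer names are hypothetical. The code checks against `Mapping` so that Flax's `FrozenDict` is also walked.

```python
from collections.abc import Mapping
from typing import List


def leaf_paths(tree: Mapping, prefix: str = "") -> List[str]:
    """List the dotted path of every leaf in a nested mapping, sorted."""
    paths: List[str] = []
    for key, value in tree.items():
        path = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, Mapping):
            # Descend into sub-trees (layers, sub-modules).
            paths.extend(leaf_paths(value, path))
        else:
            paths.append(path)
    return sorted(paths)


# Hypothetical trees standing in for the two loaded `params` dicts.
params_a = {"Conv_0": {"kernel": 0, "bias": 0}, "GDN_0": {"gamma": 0}}
params_b = {"GDN_0": {"gamma": 1}, "Conv_0": {"bias": 1, "kernel": 1}}
print(leaf_paths(params_a) == leaf_paths(params_b))  # → True
```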

### 4. Use the model

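```python
from jax import numpy as jnp

# Build the model from the downloaded config (assumes `Baseline` takes the
# config as its only constructor argument).
model = PerceptNet(config)
# Dummy batch: one 384x512 RGB image in (batch, height, width, channels) layout.
pred = model.apply({"params": params}, jnp.ones((1, 384, 512, 3)))
```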
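The example above feeds a dummy batch of shape `(1, 384, 512, 3)`, which is batch, height, width, channels. A real image needs the same channels-last layout with a leading batch axis. Below is a minimal NumPy sketch. Scaling `uint8` pixels to `[0, 1]` is an assumption about the expected input range, so check it against how the model was trained.

```python
import numpy as np


def to_model_batch(image: np.ndarray) -> np.ndarray:
    """Turn one (H, W, 3) uint8 RGB image into a (1, H, W, 3) float32 batch.

    ASSUMPTION: the model expects channels-last float inputs scaled to [0, 1].
    """
    if image.ndim != 3 or image.shape[-1] != 3:
        raise ValueError(f"expected an (H, W, 3) image, got shape {image.shape}")
    scaled = image.astype(np.float32) / 255.0  # uint8 [0, 255] -> float [0, 1]
    return scaled[np.newaxis, ...]  # add the leading batch axis


# Stand-in for a real image loaded from disk (e.g. with Pillow).
dummy = np.full((384, 512, 3), 255, dtype=np.uint8)
batch = to_model_batch(dummy)
print(batch.shape, batch.dtype)
```

You can then pass the result to the model as `model.apply({"params": params}, jnp.asarray(batch))`.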