improved
Browse files
README.md
CHANGED
@@ -51,6 +51,8 @@ weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-bio-fitted",
|
|
51 |
filename="weights.safetensors")
|
52 |
variables = load_file(weights_path)
|
53 |
variables = flax.traverse_util.unflatten_dict(variables, sep=".")
|
|
|
|
|
54 |
```
|
55 |
|
56 |
#### 3.2. Using `mgspack`
|
@@ -59,10 +61,9 @@ weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-bio-fitted",
|
|
59 |
filename="weights.msgpack")
|
60 |
with open(weights_path, "rb") as f:
|
61 |
variables = orbax.checkpoint.msgpack_utils.msgpack_restore(f.read())
|
|
|
62 |
state = variables["state"]
|
63 |
params = variables["params"]
|
64 |
-
state = jax.tree_util.tree_map(lambda x: jnp.array(x), state)
|
65 |
-
params = jax.tree_util.tree_map(lambda x: jnp.array(x), params)
|
66 |
```
|
67 |
|
68 |
### 4. Use the model
|
|
|
51 |
filename="weights.safetensors")
|
52 |
variables = load_file(weights_path)
|
53 |
variables = flax.traverse_util.unflatten_dict(variables, sep=".")
|
54 |
+
state = variables["state"]
|
55 |
+
params = variables["params"]
|
56 |
```
|
57 |
|
58 |
#### 3.2. Using `mgspack`
|
|
|
61 |
filename="weights.msgpack")
|
62 |
with open(weights_path, "rb") as f:
|
63 |
variables = orbax.checkpoint.msgpack_utils.msgpack_restore(f.read())
|
64 |
+
variables = jax.tree_util.tree_map(lambda x: jnp.array(x), variables)
|
65 |
state = variables["state"]
|
66 |
params = variables["params"]
|
|
|
|
|
67 |
```
|
68 |
|
69 |
### 4. Use the model
|