Jorgvt commited on
Commit
1387cc6
·
1 Parent(s): 6b740f5
Files changed (1) hide show
  1. README.md +3 -2
README.md CHANGED
@@ -51,6 +51,8 @@ weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-bio-fitted",
51
  filename="weights.safetensors")
52
  variables = load_file(weights_path)
53
  variables = flax.traverse_util.unflatten_dict(variables, sep=".")
 
 
54
  ```
55
 
56
  #### 3.2. Using `mgspack`
@@ -59,10 +61,9 @@ weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-bio-fitted",
59
  filename="weights.msgpack")
60
  with open(weights_path, "rb") as f:
61
  variables = orbax.checkpoint.msgpack_utils.msgpack_restore(f.read())
 
62
  state = variables["state"]
63
  params = variables["params"]
64
- state = jax.tree_util.tree_map(lambda x: jnp.array(x), state)
65
- params = jax.tree_util.tree_map(lambda x: jnp.array(x), params)
66
  ```
67
 
68
  ### 4. Use the model
 
51
  filename="weights.safetensors")
52
  variables = load_file(weights_path)
53
  variables = flax.traverse_util.unflatten_dict(variables, sep=".")
54
+ state = variables["state"]
55
+ params = variables["params"]
56
  ```
57
 
58
  #### 3.2. Using `mgspack`
 
61
  filename="weights.msgpack")
62
  with open(weights_path, "rb") as f:
63
  variables = orbax.checkpoint.msgpack_utils.msgpack_restore(f.read())
64
+ variables = jax.tree_util.tree_map(lambda x: jnp.array(x), variables)
65
  state = variables["state"]
66
  params = variables["params"]
 
 
67
  ```
68
 
69
  ### 4. Use the model