---
license: apache-2.0
tags:
- jax
- rl
- jumanji
---

# BinPack-v2
This model was trained with the A2C agent on the Jumanji BinPack environment (`BinPack-v2`).


**Developed by:** InstaDeep

### Model Sources


- **Repository:** [Jumanji](https://github.com/instadeepai/jumanji)
- **Paper:** TBD

### How to use

[Notebook](#)

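The model code and its requirements live in the [Jumanji](https://github.com/instadeepai/jumanji) repository. Clone the repository, move into its root directory, and install it in editable mode:

```bash
git clone https://github.com/instadeepai/jumanji.git
cd jumanji
pip install -e .
```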

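The script below downloads the trained model state from the Hugging Face Hub, rolls out a few episodes with the deterministic policy, and saves a GIF and a PNG of the rollouts. Run it from the repository root so that the relative Hydra config path resolves: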

```python
import pickle

import jax
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
from hydra import compose, initialize

from jumanji.training.setup_train import setup_agent, setup_env
from jumanji.training.utils import first_from_device

# Initialise the Hydra config for the BinPack environment and the A2C agent.
with initialize(version_base=None, config_path="jumanji/training/configs"):
    cfg = compose(config_name="config.yaml", overrides=["env=binpack", "agent=a2c"])

# Download the trained model state from the Hugging Face Hub.
REPO_ID = "InstaDeepAI/jumanji-binpack-v2-a2c-benchmark"
FILENAME = "BinPack-v2_training_state"

model_weights = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

with open(model_weights, "rb") as f:
    training_state = pickle.load(f)

# Take the parameters from the first device, then build the environment
# and a deterministic (non-stochastic) policy from the actor parameters.
params = first_from_device(training_state.params_state.params)
env = setup_env(cfg).unwrapped
agent = setup_agent(cfg, env)
policy = jax.jit(agent.make_policy(params.actor, stochastic=False))

# Jit reset and step once, outside the rollout loop.
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)

# Roll out a few episodes.
NUM_EPISODES = 10

states = []
key = jax.random.PRNGKey(cfg.seed)
for episode in range(NUM_EPISODES):
    key, reset_key = jax.random.split(key)
    state, timestep = reset_fn(reset_key)
    while not timestep.last():
        key, action_key = jax.random.split(key)
        # Add a leading batch axis to every leaf of the observation.
        observation = jax.tree_util.tree_map(lambda x: x[None], timestep.observation)
        action, _ = policy(observation, action_key)
        # Drop the batch axis from the action before stepping the environment.
        state, timestep = step_fn(state, action.squeeze(axis=0))
        states.append(state)
    # Freeze the terminal frame to pause the GIF.
    for _ in range(10):
        states.append(state)

# Animate the rollouts as a GIF.
env.animate(states, interval=150).save("./binpack.gif")

# Render a single frame and save it as a PNG.
env.render(states[117])
plt.savefig("binpack.png", dpi=300)
```
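The rollout loop adds a leading batch axis to the observation before calling the policy and removes it from the action afterwards, presumably because the agent's networks are written for batched inputs. The standalone sketch below shows that pattern in isolation. `ToyObservation` and `add_batch_dim` are hypothetical names for illustration, not Jumanji types or functions; only `jax.tree_util.tree_map`, array indexing, and `squeeze` are real JAX APIs here.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyObservation(NamedTuple):
    # Hypothetical stand-in for an environment observation pytree.
    grid: jnp.ndarray
    action_mask: jnp.ndarray


def add_batch_dim(tree):
    """Add a leading axis of size 1 to every array leaf of a pytree."""
    return jax.tree_util.tree_map(lambda x: x[None], tree)


obs = ToyObservation(
    grid=jnp.zeros((4, 4)),
    action_mask=jnp.ones((3,), dtype=bool),
)
batched = add_batch_dim(obs)
print(batched.grid.shape, batched.action_mask.shape)  # (1, 4, 4) (1, 3)

# A batched policy output of shape (1, 2) (a hypothetical two-component
# action) becomes a single unbatched action with squeeze(axis=0).
batched_action = jnp.array([[2, 1]])
action = batched_action.squeeze(axis=0)
print(action.shape)  # (2,)
```

Because a `NamedTuple` is a pytree, `tree_map` returns another `ToyObservation` with every field batched, so the same one-liner works for any observation structure.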