d-byrne commited on
Commit
636385d
1 Parent(s): 4c5188b

updated readme

Browse files
Files changed (1) hide show
  1. README.md +88 -0
README.md CHANGED
@@ -1,3 +1,91 @@
1
  ---
2
  license: apache-2.0
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
+ tags:
4
+ - jax
5
+ - rl
6
+ - jumanji
7
  ---
8
+
9
+ # BinPack-V2
10
+ This model is trained on the Jumanji BinPack environment
11
+
12
+
13
+ **Developed by:** InstaDeep
14
+
15
+ ### Model Sources
16
+
17
+ <!-- Provide the basic links for the model. -->
18
+
19
+ - **Repository:** [Jumanji](https://github.com/instadeepai/jumanji)
20
+ - **Paper:** TBD
21
+
22
+ ### How to use
23
+
24
+ [Notebook](#)
25
+
26
+ Go to the jumanji repo for the primary model and requirements. Clone the repo and navigate to the root directory.
27
+
28
+ ```
29
+ pip install -e .
30
+ ```
31
+
32
+ Below is an example script for loading and running the Jumanji model
33
+
34
+ ```python
35
+ import pickle
36
+ import joblib
37
+
38
+ import jax
39
+ from hydra import compose, initialize
40
+ from huggingface_hub import hf_hub_download
41
+
42
+
43
+ from jumanji.training.setup_train import setup_agent, setup_env
44
+ from jumanji.training.utils import first_from_device
45
+
46
+ # initialise the config
47
+ with initialize(version_base=None, config_path="jumanji/training/configs"):
48
+ cfg = compose(config_name="config.yaml", overrides=["env=snake", "agent=a2c"])
49
+
50
+ # get model state from HF
51
+ REPO_ID = "InstaDeepAI/jumanji-binpack-v2-a2c-benchmark"
52
+ FILENAME = "BinPack-v2_training_state"
53
+
54
+ model_weights = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
55
+
56
+ with open(model_weights,"rb") as f:
57
+ training_state = pickle.load(f)
58
+
59
+ params = first_from_device(training_state.params_state.params)
60
+ env = setup_env(cfg).unwrapped
61
+ agent = setup_agent(cfg, env)
62
+ policy = jax.jit(agent.make_policy(params.actor, stochastic = False))
63
+
64
+ # rollout a few episodes
65
+ NUM_EPISODES = 10
66
+
67
+ states = []
68
+ key = jax.random.PRNGKey(cfg.seed)
69
+ for episode in range(NUM_EPISODES):
70
+ key, reset_key = jax.random.split(key)
71
+ state, timestep = jax.jit(env.reset)(reset_key)
72
+ while not timestep.last():
73
+ key, action_key = jax.random.split(key)
74
+ observation = jax.tree_util.tree_map(lambda x: x[None], timestep.observation)
75
+ action, _ = policy(observation, action_key)
76
+ state, timestep = jax.jit(env.step)(state, action.squeeze(axis=0))
77
+ states.append(state)
78
+ # Freeze the terminal frame to pause the GIF.
79
+ for _ in range(10):
80
+ states.append(state)
81
+
82
+ # animate a GIF
83
+ env.animate(states, interval=150).save("./snake.gif")
84
+
85
+ # save PNG
86
+ import matplotlib.pyplot as plt
87
+ %matplotlib inline
88
+ env.render(states[117])
89
+ plt.savefig("connector.png", dpi=300)
90
+
91
+ ```