vwxyzjn committed
Commit e835799
1 Parent(s): be2c761

pushing model

.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ cleanba_ppo_envpool_procgen.cleanrl_model filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,100 @@
+ ---
+ tags:
+ - MinerHard-v0
+ - deep-reinforcement-learning
+ - reinforcement-learning
+ - custom-implementation
+ library_name: cleanrl
+ model-index:
+ - name: PPO
+   results:
+   - task:
+       type: reinforcement-learning
+       name: reinforcement-learning
+     dataset:
+       name: MinerHard-v0
+       type: MinerHard-v0
+     metrics:
+     - type: mean_reward
+       value: 9.00 +/- 2.41
+       name: mean_reward
+       verified: false
+ ---
+
+ # (CleanRL) **PPO** Agent Playing **MinerHard-v0**
+
+ This is a trained model of a PPO agent playing MinerHard-v0.
+ The model was trained by using [CleanRL](https://github.com/vwxyzjn/cleanrl) and the most up-to-date training code can be
+ found [here](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/cleanba_ppo_envpool_procgen.py).
+
+ ## Get Started
+
+ To use this model, please install the `cleanrl` package with the following command:
+
+ ```
+ pip install "cleanrl[jax,envpool,atari]"
+ python -m cleanrl_utils.enjoy --exp-name cleanba_ppo_envpool_procgen --env-id MinerHard-v0
+ ```
+
+ Please refer to the [documentation](https://docs.cleanrl.dev/get-started/zoo/) for more detail.
+
+
+ ## Command to reproduce the training
+
+ ```bash
+ curl -OL https://huggingface.co/cleanrl/MinerHard-v0-cleanba_ppo_envpool_procgen-seed1/raw/main/cleanba_ppo_envpool_procgen.py
+ curl -OL https://huggingface.co/cleanrl/MinerHard-v0-cleanba_ppo_envpool_procgen-seed1/raw/main/pyproject.toml
+ curl -OL https://huggingface.co/cleanrl/MinerHard-v0-cleanba_ppo_envpool_procgen-seed1/raw/main/poetry.lock
+ poetry install --all-extras
+ python cleanba_ppo_envpool_procgen.py --distributed --track --save-model --upload-model --wandb-project-name cleanba --hf-entity cleanrl --env-id MinerHard-v0 --seed 1
+ ```
+
+ # Hyperparameters
+ ```python
+ {'actor_device_ids': [0],
+  'actor_devices': ['gpu:0'],
+  'anneal_lr': False,
+  'async_batch_size': 20,
+  'async_update': 3,
+  'batch_size': 61440,
+  'capture_video': False,
+  'clip_coef': 0.2,
+  'cuda': True,
+  'distributed': True,
+  'ent_coef': 0.01,
+  'env_id': 'MinerHard-v0',
+  'exp_name': 'cleanba_ppo_envpool_procgen',
+  'gae_lambda': 0.95,
+  'gamma': 0.999,
+  'global_learner_decices': ['gpu:0', 'gpu:1', 'gpu:2', 'gpu:3'],
+  'hf_entity': 'cleanrl',
+  'learner_device_ids': [0],
+  'learner_devices': ['gpu:0'],
+  'learning_rate': 0.0005,
+  'local_batch_size': 15360,
+  'local_minibatch_size': 1920,
+  'local_num_envs': 60,
+  'local_rank': 0,
+  'max_grad_norm': 0.5,
+  'minibatch_size': 7680,
+  'norm_adv': True,
+  'num_envs': 240,
+  'num_minibatches': 8,
+  'num_steps': 256,
+  'num_updates': 1627,
+  'profile': False,
+  'save_model': True,
+  'seed': 1,
+  'target_kl': None,
+  'test_actor_learner_throughput': False,
+  'torch_deterministic': True,
+  'total_timesteps': 100000000,
+  'track': True,
+  'update_epochs': 3,
+  'upload_model': True,
+  'vf_coef': 0.5,
+  'wandb_entity': None,
+  'wandb_project_name': 'cleanba',
+  'world_size': 4}
+ ```
+
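The batch-size entries in the hyperparameter table above are derived from a handful of base settings rather than set by hand; the derivations live in `parse_args()` and the `__main__` block of the training script below. A quick arithmetic check (not part of the original model card; values copied from the table):

```python
# Base settings from the hyperparameter table above.
local_num_envs, num_steps = 60, 256
world_size, num_minibatches = 4, 8
async_batch_size = 20
total_timesteps = 100_000_000

local_batch_size = local_num_envs * num_steps               # 15360
local_minibatch_size = local_batch_size // num_minibatches  # 1920
batch_size = local_batch_size * world_size                  # 61440
minibatch_size = local_minibatch_size * world_size          # 7680
num_envs = local_num_envs * world_size                      # 240
async_update = local_num_envs // async_batch_size           # 3
num_updates = total_timesteps // batch_size                 # 1627
```

All of these match the logged values, which is a useful sanity check when changing `--local-num-envs`, `--num-steps`, or the number of processes.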
cleanba_ppo_envpool_procgen.cleanrl_model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:93910ef2db7a150b5bbc04a24549cb013d7af48064eb532ad6f1619df02eceaf
+ size 2507060
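The `.cleanrl_model` file behind this LFS pointer is written near the end of the training script below with `flax.serialization.to_bytes([vars(args), [network_params, actor_params, critic_params]])`. A minimal sketch for fetching and decoding it outside of `cleanrl_utils.enjoy`, assuming the standard `huggingface_hub` download helper (the exact layout of the decoded pytree is an assumption here, so inspect it before relying on it):

```python
import flax.serialization
from huggingface_hub import hf_hub_download

# Download the serialized agent from this repository.
path = hf_hub_download(
    repo_id="cleanrl/MinerHard-v0-cleanba_ppo_envpool_procgen-seed1",
    filename="cleanba_ppo_envpool_procgen.cleanrl_model",
)

with open(path, "rb") as f:
    # flax.serialization.to_bytes uses msgpack, so msgpack_restore returns the raw
    # nested structure: roughly the saved args plus the network/actor/critic params.
    restored = flax.serialization.msgpack_restore(f.read())

print(type(restored))
```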
cleanba_ppo_envpool_procgen.py ADDED
@@ -0,0 +1,818 @@
1
+ import argparse
2
+ import os
3
+ import random
4
+ import time
5
+ import uuid
6
+ from collections import deque
7
+ from distutils.util import strtobool
8
+ from functools import partial
9
+ from typing import Sequence
10
+
11
+ os.environ[
12
+ "XLA_PYTHON_CLIENT_MEM_FRACTION"
13
+ ] = "0.6" # see https://github.com/google/jax/discussions/6332#discussioncomment-1279991
14
+ os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false " "intra_op_parallelism_threads=1"
15
+ import queue
16
+ import threading
17
+
18
+ import envpool
19
+ import flax
20
+ import flax.linen as nn
21
+ import gym
22
+ import jax
23
+ import jax.numpy as jnp
24
+ import numpy as np
25
+ import optax
26
+ from flax.linen.initializers import constant, orthogonal
27
+ from flax.training.train_state import TrainState
28
+ from tensorboardX import SummaryWriter
29
+
30
+
31
+ def parse_args():
32
+ # fmt: off
33
+ parser = argparse.ArgumentParser()
34
+ parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
35
+ help="the name of this experiment")
36
+ parser.add_argument("--seed", type=int, default=1,
37
+ help="seed of the experiment")
38
+ parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
39
+ help="if toggled, `torch.backends.cudnn.deterministic=False`")
40
+ parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
41
+ help="if toggled, cuda will be enabled by default")
42
+ parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
43
+ help="if toggled, this experiment will be tracked with Weights and Biases")
44
+ parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
45
+ help="the wandb's project name")
46
+ parser.add_argument("--wandb-entity", type=str, default=None,
47
+ help="the entity (team) of wandb's project")
48
+ parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
49
+ help="whether to capture videos of the agent performances (check out `videos` folder)")
50
+ parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
51
+ help="whether to save model into the `runs/{run_name}` folder")
52
+ parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
53
+ help="whether to upload the saved model to huggingface")
54
+ parser.add_argument("--hf-entity", type=str, default="",
55
+ help="the user or org name of the model repository from the Hugging Face Hub")
56
+
57
+ # Algorithm specific arguments
58
+ parser.add_argument("--env-id", type=str, default="BigfishHard-v0",
59
+ help="the id of the environment")
60
+ parser.add_argument("--total-timesteps", type=int, default=100000000,
61
+ help="total timesteps of the experiments")
62
+ parser.add_argument("--learning-rate", type=float, default=5e-4,
63
+ help="the learning rate of the optimizer")
64
+ parser.add_argument("--local-num-envs", type=int, default=60,
65
+ help="the number of parallel game environments")
66
+ parser.add_argument("--async-batch-size", type=int, default=20,
67
+ help="the envpool's batch size in the async mode")
68
+ parser.add_argument("--num-steps", type=int, default=256,
69
+ help="the number of steps to run in each environment per policy rollout")
70
+ parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
71
+ help="Toggle learning rate annealing for policy and value networks")
72
+ parser.add_argument("--gamma", type=float, default=0.999,
73
+ help="the discount factor gamma")
74
+ parser.add_argument("--gae-lambda", type=float, default=0.95,
75
+ help="the lambda for the general advantage estimation")
76
+ parser.add_argument("--num-minibatches", type=int, default=8,
77
+ help="the number of mini-batches")
78
+ parser.add_argument("--update-epochs", type=int, default=3,
79
+ help="the K epochs to update the policy")
80
+ parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
81
+ help="Toggles advantages normalization")
82
+ parser.add_argument("--clip-coef", type=float, default=0.2,
83
+ help="the surrogate clipping coefficient")
84
+ parser.add_argument("--ent-coef", type=float, default=0.01,
85
+ help="coefficient of the entropy")
86
+ parser.add_argument("--vf-coef", type=float, default=0.5,
87
+ help="coefficient of the value function")
88
+ parser.add_argument("--max-grad-norm", type=float, default=0.5,
89
+ help="the maximum norm for the gradient clipping")
90
+ parser.add_argument("--target-kl", type=float, default=None,
91
+ help="the target KL divergence threshold")
92
+
93
+ parser.add_argument("--actor-device-ids", type=int, nargs="+", default=[0], # type is actually List[int]
94
+ help="the device ids that actor workers will use (currently only support 1 device)")
95
+ parser.add_argument("--learner-device-ids", type=int, nargs="+", default=[0], # type is actually List[int]
96
+ help="the device ids that learner workers will use")
97
+ parser.add_argument("--distributed", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
98
+ help="whether to use `jax.distributed`")
99
+ parser.add_argument("--profile", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
100
+ help="whether to call block_until_ready() for profiling")
101
+ parser.add_argument("--test-actor-learner-throughput", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
102
+ help="whether to test actor-learner throughput by removing the actor-learner communication")
103
+ args = parser.parse_args()
104
+ args.local_batch_size = int(args.local_num_envs * args.num_steps)
105
+ args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
106
+ args.num_updates = args.total_timesteps // args.local_batch_size
107
+ args.async_update = int(args.local_num_envs / args.async_batch_size)
108
+ assert len(args.actor_device_ids) == 1, "only 1 actor_device_ids is supported now"
109
+ # fmt: on
110
+ return args
111
+
112
+
113
+ def make_env(env_id, seed, num_envs, async_batch_size=1):
114
+ def thunk():
115
+ envs = envpool.make(
116
+ env_id,
117
+ env_type="gym",
118
+ num_envs=num_envs,
119
+ batch_size=async_batch_size,
120
+ seed=seed,
121
+ )
122
+ envs.num_envs = num_envs
123
+ envs.single_action_space = envs.action_space
124
+ envs.single_observation_space = envs.observation_space
125
+ envs.is_vector_env = True
126
+ return envs
127
+
128
+ return thunk
129
+
130
+
131
+ class ResidualBlock(nn.Module):
132
+ channels: int
133
+
134
+ @nn.compact
135
+ def __call__(self, x):
136
+ inputs = x
137
+ x = nn.relu(x)
138
+ x = nn.Conv(
139
+ self.channels,
140
+ kernel_size=(3, 3),
141
+ )(x)
142
+ x = nn.relu(x)
143
+ x = nn.Conv(
144
+ self.channels,
145
+ kernel_size=(3, 3),
146
+ )(x)
147
+ return x + inputs
148
+
149
+
150
+ class ConvSequence(nn.Module):
151
+ channels: int
152
+
153
+ @nn.compact
154
+ def __call__(self, x):
155
+ x = nn.Conv(
156
+ self.channels,
157
+ kernel_size=(3, 3),
158
+ )(x)
159
+ x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME")
160
+ x = ResidualBlock(self.channels)(x)
161
+ x = ResidualBlock(self.channels)(x)
162
+ return x
163
+
164
+
165
+ class Network(nn.Module):
166
+ channelss: Sequence[int] = (16, 32, 32)
167
+
168
+ @nn.compact
169
+ def __call__(self, x):
170
+ x = jnp.transpose(x, (0, 2, 3, 1))
171
+ x = x / (255.0)
172
+ for channels in self.channelss:
173
+ x = ConvSequence(channels)(x)
174
+ x = nn.relu(x)
175
+ x = x.reshape((x.shape[0], -1))
176
+ x = nn.Dense(256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x)
177
+ x = nn.relu(x)
178
+ return x
179
+
180
+
181
+ class Critic(nn.Module):
182
+ @nn.compact
183
+ def __call__(self, x):
184
+ return nn.Dense(1, kernel_init=orthogonal(1), bias_init=constant(0.0))(x)
185
+
186
+
187
+ class Actor(nn.Module):
188
+ action_dim: int
189
+
190
+ @nn.compact
191
+ def __call__(self, x):
192
+ return nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(x)
193
+
194
+
195
+ @flax.struct.dataclass
196
+ class AgentParams:
197
+ network_params: flax.core.FrozenDict
198
+ actor_params: flax.core.FrozenDict
199
+ critic_params: flax.core.FrozenDict
200
+
201
+
202
+ @partial(jax.jit, static_argnums=(3))
203
+ def get_action_and_value(
204
+ params: TrainState,
205
+ next_obs: np.ndarray,
206
+ key: jax.random.PRNGKey,
207
+ action_dim: int,
208
+ ):
209
+ next_obs = jnp.array(next_obs)
210
+ hidden = Network().apply(params.network_params, next_obs)
211
+ logits = Actor(action_dim).apply(params.actor_params, hidden)
212
+ # sample action: Gumbel-softmax trick
213
+ # see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
214
+ key, subkey = jax.random.split(key)
215
+ u = jax.random.uniform(subkey, shape=logits.shape)
216
+ action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
217
+ logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
218
+ value = Critic().apply(params.critic_params, hidden)
219
+ return next_obs, action, logprob, value.squeeze(), key
220
+
221
+
222
+ def prepare_data(
223
+ obs: list,
224
+ dones: list,
225
+ values: list,
226
+ actions: list,
227
+ logprobs: list,
228
+ env_ids: list,
229
+ rewards: list,
230
+ ):
231
+ obs = jnp.asarray(obs)
232
+ dones = jnp.asarray(dones)
233
+ values = jnp.asarray(values)
234
+ actions = jnp.asarray(actions)
235
+ logprobs = jnp.asarray(logprobs)
236
+ env_ids = jnp.asarray(env_ids)
237
+ rewards = jnp.asarray(rewards)
238
+
239
+ # TODO: in an unlikely event, one of the envs might not have stepped at all, which may result in unexpected behavior
240
+ T, B = env_ids.shape
241
+ index_ranges = jnp.arange(T * B, dtype=jnp.int32)
242
+ next_index_ranges = jnp.zeros_like(index_ranges, dtype=jnp.int32)
243
+ last_env_ids = jnp.zeros(args.local_num_envs, dtype=jnp.int32) - 1
244
+
245
+ def f(carry, x):
246
+ last_env_ids, next_index_ranges = carry
247
+ env_id, index_range = x
248
+ next_index_ranges = next_index_ranges.at[last_env_ids[env_id]].set(
249
+ jnp.where(last_env_ids[env_id] != -1, index_range, next_index_ranges[last_env_ids[env_id]])
250
+ )
251
+ last_env_ids = last_env_ids.at[env_id].set(index_range)
252
+ return (last_env_ids, next_index_ranges), None
253
+
254
+ (last_env_ids, next_index_ranges), _ = jax.lax.scan(
255
+ f,
256
+ (last_env_ids, next_index_ranges),
257
+ (env_ids.reshape(-1), index_ranges),
258
+ )
259
+
260
+ # rewards is off by one time step
261
+ rewards = rewards.reshape(-1)[next_index_ranges].reshape((args.num_steps) * args.async_update, args.async_batch_size)
262
+ advantages, returns, _, final_env_ids = compute_gae(env_ids, rewards, values, dones)
263
+ # b_inds = jnp.nonzero(final_env_ids.reshape(-1), size=(args.num_steps) * args.async_update * args.async_batch_size)[0] # useful for debugging
264
+ b_obs = obs.reshape((-1,) + obs.shape[2:])
265
+ b_actions = actions.reshape(-1)
266
+ b_logprobs = logprobs.reshape(-1)
267
+ b_advantages = advantages.reshape(-1)
268
+ b_returns = returns.reshape(-1)
269
+ return b_obs, b_actions, b_logprobs, b_advantages, b_returns
270
+
271
+
272
+ def rollout(
273
+ key: jax.random.PRNGKey,
274
+ args,
275
+ rollout_queue,
276
+ params_queue: queue.Queue,
277
+ writer,
278
+ learner_devices,
279
+ ):
280
+ envs = make_env(args.env_id, args.seed, args.local_num_envs, args.async_batch_size)()
281
+ len_actor_device_ids = len(args.actor_device_ids)
282
+ global_step = 0
283
+ # TRY NOT TO MODIFY: start the game
284
+ start_time = time.time()
285
+
286
+ # put data in the last index
287
+ episode_returns = np.zeros((args.local_num_envs,), dtype=np.float32)
288
+ returned_episode_returns = np.zeros((args.local_num_envs,), dtype=np.float32)
289
+ episode_lengths = np.zeros((args.local_num_envs,), dtype=np.float32)
290
+ returned_episode_lengths = np.zeros((args.local_num_envs,), dtype=np.float32)
291
+ envs.async_reset()
292
+
293
+ params_queue_get_time = deque(maxlen=10)
294
+ rollout_time = deque(maxlen=10)
295
+ rollout_queue_put_time = deque(maxlen=10)
296
+ actor_policy_version = 0
297
+ for update in range(1, args.num_updates + 2):
298
+ # NOTE: This is a major difference from the sync version:
299
+ # at the end of the rollout phase, the sync version will have the next observation
300
+ # ready for the value bootstrap, but the async version will not have it.
301
+ # for this reason we do `num_steps + 1`` to get the extra states for value bootstrapping.
302
+ # but note that the extra states are not used for the loss computation in the next iteration,
303
+ # while the sync version will use the extra state for the loss computation.
304
+ update_time_start = time.time()
305
+ obs = []
306
+ dones = []
307
+ actions = []
308
+ logprobs = []
309
+ values = []
310
+ env_ids = []
311
+ rewards = []
312
+ truncations = []
313
+ terminations = []
314
+ env_recv_time = 0
315
+ inference_time = 0
316
+ storage_time = 0
317
+ env_send_time = 0
318
+
319
+ # NOTE: `update != 2` is actually IMPORTANT — it allows us to start running policy collection
320
+ # concurrently with the learning process. It also ensures the actor's policy version is only 1 step
321
+ # behind the learner's policy version
322
+ params_queue_get_time_start = time.time()
323
+ if update != 2:
324
+ params = params_queue.get()
325
+ actor_policy_version += 1
326
+ params_queue_get_time.append(time.time() - params_queue_get_time_start)
327
+ writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
328
+ rollout_time_start = time.time()
329
+ for _ in range(
330
+ args.async_update, (args.num_steps + 1) * args.async_update
331
+ ): # num_steps + 1 to get the states for value bootstrapping.
332
+ env_recv_time_start = time.time()
333
+ next_obs, next_reward, next_done, info = envs.recv()
334
+ env_recv_time += time.time() - env_recv_time_start
335
+ global_step += len(next_done) * len_actor_device_ids * args.world_size
336
+ env_id = info["env_id"]
337
+
338
+ inference_time_start = time.time()
339
+ next_obs, action, logprob, value, key = get_action_and_value(params, next_obs, key, envs.single_action_space.n)
340
+ inference_time += time.time() - inference_time_start
341
+
342
+ env_send_time_start = time.time()
343
+ envs.send(np.array(action), env_id)
344
+ env_send_time += time.time() - env_send_time_start
345
+ storage_time_start = time.time()
346
+ obs.append(next_obs)
347
+ dones.append(next_done)
348
+ values.append(value)
349
+ actions.append(action)
350
+ logprobs.append(logprob)
351
+ env_ids.append(env_id)
352
+ rewards.append(next_reward)
353
+
354
+ # info["TimeLimit.truncated"] has a bug https://github.com/sail-sg/envpool/issues/239
355
+ # so we use our own truncated flag
356
+ truncated = info["elapsed_step"] >= envs.spec.config.max_episode_steps
357
+ truncations.append(truncated)
358
+ terminations.append(next_done)
359
+ episode_returns[env_id] += next_reward
360
+ returned_episode_returns[env_id] = np.where(
361
+ next_done + truncated, episode_returns[env_id], returned_episode_returns[env_id]
362
+ )
363
+ episode_returns[env_id] *= (1 - next_done) * (1 - truncated)
364
+ episode_lengths[env_id] += 1
365
+ returned_episode_lengths[env_id] = np.where(
366
+ next_done + truncated, episode_lengths[env_id], returned_episode_lengths[env_id]
367
+ )
368
+ episode_lengths[env_id] *= (1 - next_done) * (1 - truncated)
369
+ storage_time += time.time() - storage_time_start
370
+ if args.profile:
371
+ action.block_until_ready()
372
+ rollout_time.append(time.time() - rollout_time_start)
373
+ writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
374
+
375
+ avg_episodic_return = np.mean(returned_episode_returns)
376
+ writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
377
+ writer.add_scalar("charts/avg_episodic_length", np.mean(returned_episode_lengths), global_step)
378
+ print(f"global_step={global_step}, avg_episodic_return={avg_episodic_return}")
379
+ print("SPS:", int(global_step / (time.time() - start_time)))
380
+ writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
381
+
382
+ writer.add_scalar("stats/truncations", np.sum(truncations), global_step)
383
+ writer.add_scalar("stats/terminations", np.sum(terminations), global_step)
384
+ writer.add_scalar("stats/env_recv_time", env_recv_time, global_step)
385
+ writer.add_scalar("stats/inference_time", inference_time, global_step)
386
+ writer.add_scalar("stats/storage_time", storage_time, global_step)
387
+ writer.add_scalar("stats/env_send_time", env_send_time, global_step)
388
+
389
+ payload = (
390
+ global_step,
391
+ actor_policy_version,
392
+ update,
393
+ obs,
394
+ dones,
395
+ values,
396
+ actions,
397
+ logprobs,
398
+ env_ids,
399
+ rewards,
400
+ )
401
+ if update == 1 or not args.test_actor_learner_throughput:
402
+ rollout_queue_put_time_start = time.time()
403
+ rollout_queue.put(payload)
404
+ rollout_queue_put_time.append(time.time() - rollout_queue_put_time_start)
405
+ writer.add_scalar("stats/rollout_queue_put_time", np.mean(rollout_queue_put_time), global_step)
406
+
407
+ writer.add_scalar(
408
+ "charts/SPS_update",
409
+ int(
410
+ args.local_num_envs
411
+ * args.num_steps
412
+ * len_actor_device_ids
413
+ * args.world_size
414
+ / (time.time() - update_time_start)
415
+ ),
416
+ global_step,
417
+ )
418
+
419
+
420
+ @partial(jax.jit, static_argnums=(3))
421
+ def get_action_and_value2(
422
+ params: flax.core.FrozenDict,
423
+ x: np.ndarray,
424
+ action: np.ndarray,
425
+ action_dim: int,
426
+ ):
427
+ hidden = Network().apply(params.network_params, x)
428
+ logits = Actor(action_dim).apply(params.actor_params, hidden)
429
+ logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
430
+ logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
431
+ logits = logits.clip(min=jnp.finfo(logits.dtype).min)
432
+ p_log_p = logits * jax.nn.softmax(logits)
433
+ entropy = -p_log_p.sum(-1)
434
+ value = Critic().apply(params.critic_params, hidden).squeeze()
435
+ return logprob, entropy, value
436
+
437
+
438
+ @jax.jit
439
+ def compute_gae(
440
+ env_ids: np.ndarray,
441
+ rewards: np.ndarray,
442
+ values: np.ndarray,
443
+ dones: np.ndarray,
444
+ ):
445
+ dones = jnp.asarray(dones)
446
+ values = jnp.asarray(values)
447
+ env_ids = jnp.asarray(env_ids)
448
+ rewards = jnp.asarray(rewards)
449
+
450
+ _, B = env_ids.shape
451
+ final_env_id_checked = jnp.zeros(args.local_num_envs, jnp.int32) - 1
452
+ final_env_ids = jnp.zeros(B, jnp.int32)
453
+ advantages = jnp.zeros(B)
454
+ lastgaelam = jnp.zeros(args.local_num_envs)
455
+ lastdones = jnp.zeros(args.local_num_envs) + 1
456
+ lastvalues = jnp.zeros(args.local_num_envs)
457
+
458
+ def compute_gae_once(carry, x):
459
+ lastvalues, lastdones, advantages, lastgaelam, final_env_ids, final_env_id_checked = carry
460
+ (
461
+ done,
462
+ value,
463
+ eid,
464
+ reward,
465
+ ) = x
466
+ nextnonterminal = 1.0 - lastdones[eid]
467
+ nextvalues = lastvalues[eid]
468
+ delta = jnp.where(final_env_id_checked[eid] == -1, 0, reward + args.gamma * nextvalues * nextnonterminal - value)
469
+ advantages = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam[eid]
470
+ final_env_ids = jnp.where(final_env_id_checked[eid] == 1, 1, 0)
471
+ final_env_id_checked = final_env_id_checked.at[eid].set(
472
+ jnp.where(final_env_id_checked[eid] == -1, 1, final_env_id_checked[eid])
473
+ )
474
+
475
+ # the last_ variables keep track of the actual `num_steps`
476
+ lastgaelam = lastgaelam.at[eid].set(advantages)
477
+ lastdones = lastdones.at[eid].set(done)
478
+ lastvalues = lastvalues.at[eid].set(value)
479
+ return (lastvalues, lastdones, advantages, lastgaelam, final_env_ids, final_env_id_checked), (
480
+ advantages,
481
+ final_env_ids,
482
+ )
483
+
484
+ (_, _, _, _, final_env_ids, final_env_id_checked), (advantages, final_env_ids) = jax.lax.scan(
485
+ compute_gae_once,
486
+ (
487
+ lastvalues,
488
+ lastdones,
489
+ advantages,
490
+ lastgaelam,
491
+ final_env_ids,
492
+ final_env_id_checked,
493
+ ),
494
+ (
495
+ dones,
496
+ values,
497
+ env_ids,
498
+ rewards,
499
+ ),
500
+ reverse=True,
501
+ )
502
+ return advantages, advantages + values, final_env_id_checked, final_env_ids
503
+
504
+
505
+ def ppo_loss(params, x, a, logp, mb_advantages, mb_returns, action_dim):
506
+ newlogprob, entropy, newvalue = get_action_and_value2(params, x, a, action_dim)
507
+ logratio = newlogprob - logp
508
+ ratio = jnp.exp(logratio)
509
+ approx_kl = ((ratio - 1) - logratio).mean()
510
+
511
+ if args.norm_adv:
512
+ mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)
513
+
514
+ # Policy loss
515
+ pg_loss1 = -mb_advantages * ratio
516
+ pg_loss2 = -mb_advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
517
+ pg_loss = jnp.maximum(pg_loss1, pg_loss2).mean()
518
+
519
+ # Value loss
520
+ v_loss = 0.5 * ((newvalue - mb_returns) ** 2).mean()
521
+
522
+ entropy_loss = entropy.mean()
523
+ loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
524
+ return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
525
+
526
+
527
+ @partial(jax.jit, static_argnums=(6))
528
+ def single_device_update(
529
+ agent_state: TrainState,
530
+ b_obs,
531
+ b_actions,
532
+ b_logprobs,
533
+ b_advantages,
534
+ b_returns,
535
+ action_dim,
536
+ key: jax.random.PRNGKey,
537
+ ):
538
+ ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
539
+
540
+ def update_epoch(carry, _):
541
+ agent_state, key = carry
542
+ key, subkey = jax.random.split(key)
543
+
544
+ # taken from: https://github.com/google/brax/blob/main/brax/training/agents/ppo/train.py
545
+ def convert_data(x: jnp.ndarray):
546
+ x = jax.random.permutation(subkey, x)
547
+ x = jnp.reshape(x, (args.num_minibatches, -1) + x.shape[1:])
548
+ return x
549
+
550
+ def update_minibatch(agent_state, minibatch):
551
+ mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns = minibatch
552
+ (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
553
+ agent_state.params,
554
+ mb_obs,
555
+ mb_actions,
556
+ mb_logprobs,
557
+ mb_advantages,
558
+ mb_returns,
559
+ action_dim,
560
+ )
561
+ grads = jax.lax.pmean(grads, axis_name="local_devices")
562
+ agent_state = agent_state.apply_gradients(grads=grads)
563
+ return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads)
564
+
565
+ agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) = jax.lax.scan(
566
+ update_minibatch,
567
+ agent_state,
568
+ (
569
+ convert_data(b_obs),
570
+ convert_data(b_actions),
571
+ convert_data(b_logprobs),
572
+ convert_data(b_advantages),
573
+ convert_data(b_returns),
574
+ ),
575
+ )
576
+ return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads)
577
+
578
+ (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl, _) = jax.lax.scan(
579
+ update_epoch, (agent_state, key), (), length=args.update_epochs
580
+ )
581
+ return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key
582
+
583
+
584
+ if __name__ == "__main__":
585
+ args = parse_args()
586
+ if args.distributed:
587
+ jax.distributed.initialize(
588
+ local_device_ids=range(len(args.learner_device_ids) + len(args.actor_device_ids)),
589
+ )
590
+ print(list(range(len(args.learner_device_ids) + len(args.actor_device_ids))))
591
+
592
+ args.world_size = jax.process_count()
593
+ args.local_rank = jax.process_index()
594
+ args.num_envs = args.local_num_envs * args.world_size
595
+ args.batch_size = args.local_batch_size * args.world_size
596
+ args.minibatch_size = args.local_minibatch_size * args.world_size
597
+ args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
598
+ args.async_update = int(args.local_num_envs / args.async_batch_size)
599
+ local_devices = jax.local_devices()
600
+ global_devices = jax.devices()
601
+ learner_devices = [local_devices[d_id] for d_id in args.learner_device_ids]
602
+ actor_devices = [local_devices[d_id] for d_id in args.actor_device_ids]
603
+ global_learner_decices = [
604
+ global_devices[d_id + process_index * len(local_devices)]
605
+ for process_index in range(args.world_size)
606
+ for d_id in args.learner_device_ids
607
+ ]
608
+ print("global_learner_decices", global_learner_decices)
609
+ args.global_learner_decices = [str(item) for item in global_learner_decices]
610
+ args.actor_devices = [str(item) for item in actor_devices]
611
+ args.learner_devices = [str(item) for item in learner_devices]
612
+
613
+ run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{uuid.uuid4()}"
614
+ if args.track and args.local_rank == 0:
615
+ import wandb
616
+
617
+ wandb.init(
618
+ project=args.wandb_project_name,
619
+ entity=args.wandb_entity,
620
+ sync_tensorboard=True,
621
+ config=vars(args),
622
+ name=run_name,
623
+ monitor_gym=True,
624
+ save_code=True,
625
+ )
626
+ writer = SummaryWriter(f"runs/{run_name}")
627
+ writer.add_text(
628
+ "hyperparameters",
629
+ "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
630
+ )
631
+
632
+ # TRY NOT TO MODIFY: seeding
633
+ random.seed(args.seed)
634
+ np.random.seed(args.seed)
635
+ key = jax.random.PRNGKey(args.seed)
636
+ key, network_key, actor_key, critic_key = jax.random.split(key, 4)
637
+
638
+ # env setup
639
+ envs = make_env(args.env_id, args.seed, args.local_num_envs, args.async_batch_size)()
640
+ assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
641
+
642
+ def linear_schedule(count):
643
+ # anneal learning rate linearly after one training iteration which contains
644
+ # (args.num_minibatches * args.update_epochs) gradient updates
645
+ frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
646
+ return args.learning_rate * frac
647
+
648
+ network = Network()
649
+ actor = Actor(action_dim=envs.single_action_space.n)
650
+ critic = Critic()
651
+ network_params = network.init(network_key, np.array([envs.single_observation_space.sample()]))
652
+ agent_state = TrainState.create(
653
+ apply_fn=None,
654
+ params=AgentParams(
655
+ network_params,
656
+ actor.init(actor_key, network.apply(network_params, np.array([envs.single_observation_space.sample()]))),
657
+ critic.init(critic_key, network.apply(network_params, np.array([envs.single_observation_space.sample()]))),
658
+ ),
659
+ tx=optax.chain(
660
+ optax.clip_by_global_norm(args.max_grad_norm),
661
+ optax.inject_hyperparams(optax.adam)(
662
+ learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
663
+ ),
664
+ ),
665
+ )
666
+ agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
667
+
668
+ multi_device_update = jax.pmap(
669
+ single_device_update,
670
+ axis_name="local_devices",
671
+ devices=global_learner_decices,
672
+ in_axes=(0, 0, 0, 0, 0, 0, None, None),
673
+ out_axes=(0, 0, 0, 0, 0, 0, None),
674
+ static_broadcasted_argnums=(6),
675
+ )
676
+
677
+ rollout_queue = queue.Queue(maxsize=1)
678
+ params_queues = []
679
+ for d_idx, d_id in enumerate(args.actor_device_ids):
680
+ params_queue = queue.Queue(maxsize=1)
681
+ params_queue.put(jax.device_put(flax.jax_utils.unreplicate(agent_state.params), local_devices[d_id]))
682
+ threading.Thread(
683
+ target=rollout,
684
+ args=(
685
+ jax.device_put(key, local_devices[d_id]),
686
+ args,
687
+ rollout_queue,
688
+ params_queue,
689
+ writer,
690
+ learner_devices,
691
+ ),
692
+ ).start()
693
+ params_queues.append(params_queue)
694
+
695
+ rollout_queue_get_time = deque(maxlen=10)
696
+ data_transfer_time = deque(maxlen=10)
697
+ learner_policy_version = 0
698
+ prepare_data = jax.jit(prepare_data, device=learner_devices[0])
699
+ while True:
700
+ learner_policy_version += 1
701
+ if learner_policy_version == 1 or not args.test_actor_learner_throughput:
702
+ rollout_queue_get_time_start = time.time()
703
+ (
704
+ global_step,
705
+ actor_policy_version,
706
+ update,
707
+ obs,
708
+ dones,
709
+ values,
710
+ actions,
711
+ logprobs,
712
+ env_ids,
713
+ rewards,
714
+ ) = rollout_queue.get()
715
+ rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
716
+ writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)
717
+
718
+ data_transfer_time_start = time.time()
719
+ b_obs, b_actions, b_logprobs, b_advantages, b_returns = prepare_data(
720
+ obs,
721
+ dones,
722
+ values,
723
+ actions,
724
+ logprobs,
725
+ env_ids,
726
+ rewards,
727
+ )
728
+ b_obs = jnp.array_split(b_obs, len(learner_devices))
729
+ b_actions = jnp.array_split(b_actions, len(learner_devices))
730
+ b_logprobs = jnp.array_split(b_logprobs, len(learner_devices))
731
+ b_advantages = jnp.array_split(b_advantages, len(learner_devices))
732
+ b_returns = jnp.array_split(b_returns, len(learner_devices))
733
+ data_transfer_time.append(time.time() - data_transfer_time_start)
734
+ writer.add_scalar("stats/data_transfer_time", np.mean(data_transfer_time), global_step)
735
+
736
+ training_time_start = time.time()
737
+ (agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key) = multi_device_update(
738
+ agent_state,
739
+ jax.device_put_sharded(b_obs, learner_devices),
740
+ jax.device_put_sharded(b_actions, learner_devices),
741
+ jax.device_put_sharded(b_logprobs, learner_devices),
742
+ jax.device_put_sharded(b_advantages, learner_devices),
743
+ jax.device_put_sharded(b_returns, learner_devices),
744
+ envs.single_action_space.n,
745
+ key,
746
+ )
747
+ if learner_policy_version == 1 or not args.test_actor_learner_throughput:
748
+ for d_idx, d_id in enumerate(args.actor_device_ids):
749
+ params_queues[d_idx].put(jax.device_put(flax.jax_utils.unreplicate(agent_state.params), local_devices[d_id]))
750
+ if args.profile:
751
+ v_loss[-1, -1, -1].block_until_ready()
752
+ writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
753
+ writer.add_scalar("stats/rollout_queue_size", rollout_queue.qsize(), global_step)
754
+ writer.add_scalar("stats/params_queue_size", params_queue.qsize(), global_step)
755
+ print(
756
+ global_step,
757
+ f"actor_policy_version={actor_policy_version}, actor_update={update}, learner_policy_version={learner_policy_version}, training time: {time.time() - training_time_start}s",
758
+ )
759
+
760
+ # TRY NOT TO MODIFY: record rewards for plotting purposes
761
+ writer.add_scalar("charts/learning_rate", agent_state.opt_state[1].hyperparams["learning_rate"][0].item(), global_step)
762
+ writer.add_scalar("losses/value_loss", v_loss[-1, -1, -1].item(), global_step)
763
+ writer.add_scalar("losses/policy_loss", pg_loss[-1, -1, -1].item(), global_step)
764
+ writer.add_scalar("losses/entropy", entropy_loss[-1, -1, -1].item(), global_step)
765
+ writer.add_scalar("losses/approx_kl", approx_kl[-1, -1, -1].item(), global_step)
766
+ writer.add_scalar("losses/loss", loss[-1, -1, -1].item(), global_step)
767
+ if update >= args.num_updates:
768
+ break
769
+
770
+ if args.save_model and args.local_rank == 0:
771
+ if args.distributed:
772
+ jax.distributed.shutdown()
773
+ agent_state = flax.jax_utils.unreplicate(agent_state)
774
+ model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
775
+ with open(model_path, "wb") as f:
776
+ f.write(
777
+ flax.serialization.to_bytes(
778
+ [
779
+ vars(args),
780
+ [
781
+ agent_state.params.network_params,
782
+ agent_state.params.actor_params,
783
+ agent_state.params.critic_params,
784
+ ],
785
+ ]
786
+ )
787
+ )
788
+ print(f"model saved to {model_path}")
789
+ from cleanrl_utils.evals.cleanba_ppo_envpool_procgen_eval import evaluate
790
+
791
+ episodic_returns = evaluate(
792
+ model_path,
793
+ make_env,
794
+ args.env_id,
795
+ eval_episodes=10,
796
+ run_name=f"{run_name}-eval",
797
+ Model=(Network, Actor, Critic),
798
+ )
799
+ for idx, episodic_return in enumerate(episodic_returns):
800
+ writer.add_scalar("eval/episodic_return", episodic_return, idx)
801
+
802
+ if args.upload_model:
803
+ from cleanrl_utils.huggingface import push_to_hub
804
+
805
+ repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
806
+ repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
807
+ push_to_hub(
808
+ args,
809
+ episodic_returns,
810
+ repo_id,
811
+ "PPO",
812
+ f"runs/{run_name}",
813
+ f"videos/{run_name}-eval",
814
+ extra_dependencies=["jax", "envpool", "atari"],
815
+ )
816
+
817
+ envs.close()
818
+ writer.close()
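`get_action_and_value` in the script above samples actions with the Gumbel-max trick, `argmax(logits - log(-log(u)))` with `u ~ Uniform(0, 1)`. A small self-contained check of that identity (illustrative only, not part of the repository):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.array([1.0, 0.0, -1.0])

# Gumbel-max: adding -log(-log(U)) noise to the logits and taking the argmax
# draws an exact sample from Categorical(softmax(logits)).
u = jax.random.uniform(key, shape=(100_000, logits.shape[0]))
samples = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)

empirical = jnp.stack([(samples == i).mean() for i in range(logits.shape[0])])
print(empirical)               # approximately [0.665, 0.245, 0.090]
print(jax.nn.softmax(logits))  # [0.665..., 0.244..., 0.090...]
```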
events.out.tfevents.1677090535.ip-26-0-140-114 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f378d0d7a580a2d62d9b03027db2440ba67ffc080313c1c30901c032155e1cd0
+ size 2354021
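The tfevents file above holds the scalars logged through `SummaryWriter` in the training script (tags such as `charts/avg_episodic_return` and `losses/value_loss`). One way to read it back offline, assuming the stock TensorBoard `EventAccumulator` API and a placeholder run directory:

```python
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# Point the accumulator at the directory that contains the tfevents file.
ea = EventAccumulator("runs/<run_name>")  # <run_name> is a placeholder
ea.Reload()

# Tag names mirror the writer.add_scalar calls in cleanba_ppo_envpool_procgen.py.
for event in ea.Scalars("charts/avg_episodic_return"):
    print(event.step, event.value)
```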
poetry.lock ADDED
The diff for this file is too large to render.
 
pyproject.toml ADDED
@@ -0,0 +1,34 @@
+ [tool.poetry]
+ name = "cleanba"
+ version = "0.1.0"
+ description = ""
+ authors = ["Costa Huang <costa.huang@outlook.com>"]
+ readme = "README.md"
+ packages = [
+     { include = "cleanba" },
+     { include = "cleanrl_utils" },
+ ]
+
+ [tool.poetry.dependencies]
+ python = "^3.8"
+ tensorboard = "^2.12.0"
+ envpool = "^0.8.1"
+ jax = "0.3.25"
+ flax = "0.6.0"
+ optax = "0.1.3"
+ huggingface-hub = "^0.12.0"
+ jaxlib = "0.3.25"
+ wandb = "^0.13.10"
+ tensorboardx = "^2.5.1"
+ chex = "0.1.5"
+ gym = "0.23.1"
+ opencv-python = "^4.7.0.68"
+ moviepy = "^1.0.3"
+
+
+ [tool.poetry.group.dev.dependencies]
+ pre-commit = "^3.0.4"
+
+ [build-system]
+ requires = ["poetry-core"]
+ build-backend = "poetry.core.masonry.api"
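With the `envpool` dependency pinned above, the async environment loop at the core of the training script can be exercised on its own. A minimal sketch mirroring `make_env` and the `recv`/`send` pattern in `rollout` (the small `num_envs`/`batch_size` values are illustrative, not the training configuration):

```python
import envpool
import numpy as np

# Same construction as make_env() in cleanba_ppo_envpool_procgen.py, just smaller.
envs = envpool.make("MinerHard-v0", env_type="gym", num_envs=8, batch_size=4, seed=1)
envs.async_reset()

for _ in range(10):
    obs, reward, done, info = envs.recv()   # a batch of at most `batch_size` envs
    env_id = info["env_id"]                 # ids of the envs in this batch
    actions = np.random.randint(envs.action_space.n, size=len(env_id))
    envs.send(actions, env_id)              # step only those envs
```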
replay.mp4 ADDED
Binary file (22.7 kB).
 
videos/MinerHard-v0__cleanba_ppo_envpool_procgen__1__0e328367-954d-4bde-9f1f-1d8e48975245-eval/0.mp4 ADDED
Binary file (22.7 kB).