Michele Milesi committed on
Commit 6b39341 • 1 Parent(s): 12794d0

feat: added dv3

agent-dreamer_v3.py ADDED
@@ -0,0 +1,126 @@
+ import argparse
+ import json
+
+ import gymnasium as gym
+ import torch
+ from lightning import Fabric
+ from omegaconf import OmegaConf
+ from sheeprl.algos.dreamer_v3.agent import build_agent
+ from sheeprl.utils.env import make_env
+ from sheeprl.utils.utils import dotdict
+
+ """This is an example agent based on SheepRL.
+
+ Usage:
+     cd sheeprl
+     diambra run python agent-dreamer_v3.py --cfg_path "./fake-logs/runs/dreamer_v3/doapp/fake-experiment/version_0/config.yaml" --checkpoint_path "./fake-logs/runs/dreamer_v3/doapp/fake-experiment/version_0/checkpoint/ckpt_1024_0.ckpt"
+ """
+
+
+ def main(cfg_path: str, checkpoint_path: str, test=False):
+     # Read the config file
+     cfg = dotdict(OmegaConf.to_container(OmegaConf.load(cfg_path), resolve=True))
+     print("Config parameters = ", json.dumps(cfg, sort_keys=True, indent=4))
+
+     # Override configs for evaluation
+     # You do not need to capture the video, since you are submitting the agent
+     # and the video is recorded by DIAMBRA
+     cfg.env.capture_video = False
+
+     # Instantiate Fabric
+     # You must use the same precision and plugins used for training.
+     precision = getattr(cfg.fabric, "precision", None)
+     plugins = getattr(cfg.fabric, "plugins", None)
+     fabric = Fabric(
+         accelerator="auto",
+         devices=1,
+         num_nodes=1,
+         precision=precision,
+         plugins=plugins,
+         strategy="auto",
+     )
+
+     # Create the environment
+     env = make_env(cfg, 0, 0)()
+     observation_space = env.observation_space
+     is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete)
+     actions_dim = tuple(
+         env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]
+     )
+     cnn_keys = cfg.algo.cnn_keys.encoder
+     mlp_keys = cfg.algo.mlp_keys.encoder
+     obs_keys = mlp_keys + cnn_keys
+
+     # Load the trained agent
+     state = fabric.load(checkpoint_path)
+     # You need to retrieve only the player.
+     # Check for each algorithm which models the `build_agent()` function returns
+     # (placed in the `agent.py` file of the algorithm), and which arguments it needs.
+     # Check also which are the keys of the checkpoint: if the `build_agent()` parameter
+     # is called `model_state`, then you retrieve the model state with `state["model"]`.
+     agent = build_agent(
+         fabric=fabric,
+         actions_dim=actions_dim,
+         is_continuous=False,
+         cfg=cfg,
+         obs_space=observation_space,
+         world_model_state=state["world_model"],
+         actor_state=state["actor"],
+         critic_state=state["critic"],
+         target_critic_state=state["target_critic"],
+     )[-1]
+     agent.eval()
+
+     # Print the policy network architecture
+     print("Policy architecture:")
+     print(agent)
+
+     o, info = env.reset()
+
+     while True:
+         # Convert numpy observations into torch observations and normalize image observations.
+         # Every algorithm has its own way to do it; check the `test()` function called
+         # in the `evaluate.py` file of the algorithm for the correct one.
+         obs = {}
+         for k in obs_keys:
+             obs[k] = (
+                 torch.from_numpy(o[k]).to(fabric.device).view(1, 1, *o[k].shape).float()
+             )
+             if k in cnn_keys:
+                 obs[k] = obs[k] / 255 - 0.5
+
+         # Select the actions: the agent returns one one-hot categorical distribution,
+         # or more than one for multi-discrete action spaces
+         actions = agent.get_actions(obs, greedy=False)
+         # Convert the actions from one-hot categorical to categorical
+         actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)
+
+         o, _, terminated, truncated, info = env.step(
+             actions.cpu().numpy().reshape(env.action_space.shape)
+         )
+
+         if terminated or truncated:
+             o, info = env.reset()
+             if info["env_done"] or test is True:
+                 break
+
+     # Close the environment
+     env.close()
+
+     # Return success
+     return 0
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--cfg_path", type=str, required=True, help="Configuration file"
+     )
+     parser.add_argument(
+         "--checkpoint_path", type=str, default="model", help="Model checkpoint"
+     )
+     parser.add_argument("--test", action="store_true", help="Test mode")
+     opt = parser.parse_args()
+     print(opt)
+
+     main(opt.cfg_path, opt.checkpoint_path, opt.test)
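The `build_agent()` call above assumes the checkpoint exposes the `world_model`, `actor`, `critic`, and `target_critic` keys. A minimal sketch for double-checking that before wiring the states in, assuming (as the script does) that `fabric.load()` with no `state` argument returns the raw checkpoint dict:

    from lightning import Fabric

    # Load the checkpoint added in this commit and list the saved states
    fabric = Fabric(accelerator="cpu", devices=1)
    state = fabric.load("./results/dreamer_v3/ckpt_1024_0.ckpt")
    # Expect at least: "actor", "critic", "target_critic", "world_model"
    print(sorted(state.keys()))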
results/dreamer_v3/ckpt_1024_0.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0f9a9b2bccb05f94a374a010446b009febd8ce8ae63105aec1455cf99c5b4cdc
+ size 389012
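The `.ckpt` above is stored through Git LFS: only this three-line pointer (spec version, sha256 oid, byte size) is committed to the repository, while the 389012-byte checkpoint itself lives in LFS storage. A minimal sketch for checking that a locally fetched copy matches the pointer:

    import hashlib

    # Both expected values come straight from the pointer file committed above
    PATH = "results/dreamer_v3/ckpt_1024_0.ckpt"
    EXPECTED_OID = "0f9a9b2bccb05f94a374a010446b009febd8ce8ae63105aec1455cf99c5b4cdc"
    EXPECTED_SIZE = 389012

    with open(PATH, "rb") as f:
        data = f.read()
    assert len(data) == EXPECTED_SIZE, f"size mismatch: {len(data)}"
    assert hashlib.sha256(data).hexdigest() == EXPECTED_OID, "sha256 mismatch"
    print("checkpoint matches its LFS pointer")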
results/dreamer_v3/config.yaml ADDED
@@ -0,0 +1,354 @@
+ num_threads: 1
+ float32_matmul_precision: high
+ dry_run: false
+ seed: 42
+ torch_use_deterministic_algorithms: false
+ torch_backends_cudnn_benchmark: true
+ torch_backends_cudnn_deterministic: false
+ cublas_workspace_config: null
+ exp_name: dreamer_v3_doapp
+ run_name: 2024-04-16_17-34-17_dreamer_v3_doapp_42
+ root_dir: dreamer_v3/doapp
+ algo:
+   name: dreamer_v3
+   total_steps: 1024
+   per_rank_batch_size: 2
+   run_test: false
+   cnn_keys:
+     encoder:
+     - frame
+     decoder:
+     - frame
+   mlp_keys:
+     encoder:
+     - own_character
+     - own_health
+     - own_side
+     - own_wins
+     - opp_character
+     - opp_health
+     - opp_side
+     - opp_wins
+     - stage
+     - timer
+     - action
+     decoder:
+     - own_character
+     - own_health
+     - own_side
+     - own_wins
+     - opp_character
+     - opp_health
+     - opp_side
+     - opp_wins
+     - stage
+     - timer
+     - action
+   world_model:
+     optimizer:
+       _target_: torch.optim.Adam
+       lr: 0.0001
+       eps: 1.0e-08
+       weight_decay: 0
+       betas:
+       - 0.9
+       - 0.999
+     discrete_size: 4
+     stochastic_size: 4
+     kl_dynamic: 0.5
+     kl_representation: 0.1
+     kl_free_nats: 1.0
+     kl_regularizer: 1.0
+     continue_scale_factor: 1.0
+     clip_gradients: 1000.0
+     decoupled_rssm: false
+     learnable_initial_recurrent_state: true
+     encoder:
+       cnn_channels_multiplier: 2
+       cnn_act: torch.nn.SiLU
+       dense_act: torch.nn.SiLU
+       mlp_layers: 1
+       cnn_layer_norm:
+         cls: sheeprl.models.models.LayerNormChannelLast
+         kw:
+           eps: 0.001
+       mlp_layer_norm:
+         cls: sheeprl.models.models.LayerNorm
+         kw:
+           eps: 0.001
+       dense_units: 8
+     recurrent_model:
+       recurrent_state_size: 8
+       layer_norm:
+         cls: sheeprl.models.models.LayerNorm
+         kw:
+           eps: 0.001
+       dense_units: 8
+     transition_model:
+       hidden_size: 8
+       dense_act: torch.nn.SiLU
+       layer_norm:
+         cls: sheeprl.models.models.LayerNorm
+         kw:
+           eps: 0.001
+     representation_model:
+       hidden_size: 8
+       dense_act: torch.nn.SiLU
+       layer_norm:
+         cls: sheeprl.models.models.LayerNorm
+         kw:
+           eps: 0.001
+     observation_model:
+       cnn_channels_multiplier: 2
+       cnn_act: torch.nn.SiLU
+       dense_act: torch.nn.SiLU
+       mlp_layers: 1
+       cnn_layer_norm:
+         cls: sheeprl.models.models.LayerNormChannelLast
+         kw:
+           eps: 0.001
+       mlp_layer_norm:
+         cls: sheeprl.models.models.LayerNorm
+         kw:
+           eps: 0.001
+       dense_units: 8
+     reward_model:
+       dense_act: torch.nn.SiLU
+       mlp_layers: 1
+       layer_norm:
+         cls: sheeprl.models.models.LayerNorm
+         kw:
+           eps: 0.001
+       dense_units: 8
+       bins: 255
+     discount_model:
+       learnable: true
+       dense_act: torch.nn.SiLU
+       mlp_layers: 1
+       layer_norm:
+         cls: sheeprl.models.models.LayerNorm
+         kw:
+           eps: 0.001
+       dense_units: 8
+   actor:
+     optimizer:
+       _target_: torch.optim.Adam
+       lr: 8.0e-05
+       eps: 1.0e-05
+       weight_decay: 0
+       betas:
+       - 0.9
+       - 0.999
+     cls: sheeprl.algos.dreamer_v3.agent.Actor
+     ent_coef: 0.0003
+     min_std: 0.1
+     max_std: 1.0
+     init_std: 2.0
+     dense_act: torch.nn.SiLU
+     mlp_layers: 1
+     layer_norm:
+       cls: sheeprl.models.models.LayerNorm
+       kw:
+         eps: 0.001
+     dense_units: 8
+     clip_gradients: 100.0
+     unimix: 0.01
+     action_clip: 1.0
+     moments:
+       decay: 0.99
+       max: 1.0
+       percentile:
+         low: 0.05
+         high: 0.95
+   critic:
+     optimizer:
+       _target_: torch.optim.Adam
+       lr: 8.0e-05
+       eps: 1.0e-05
+       weight_decay: 0
+       betas:
+       - 0.9
+       - 0.999
+     dense_act: torch.nn.SiLU
+     mlp_layers: 1
+     layer_norm:
+       cls: sheeprl.models.models.LayerNorm
+       kw:
+         eps: 0.001
+     dense_units: 8
+     per_rank_target_network_update_freq: 1
+     tau: 0.02
+     bins: 255
+     clip_gradients: 100.0
+   gamma: 0.996996996996997
+   lmbda: 0.95
+   horizon: 15
+   replay_ratio: 0.0625
+   learning_starts: 1024
+   per_rank_pretrain_steps: 0
+   per_rank_sequence_length: 64
+   cnn_layer_norm:
+     cls: sheeprl.models.models.LayerNormChannelLast
+     kw:
+       eps: 0.001
+   mlp_layer_norm:
+     cls: sheeprl.models.models.LayerNorm
+     kw:
+       eps: 0.001
+   dense_units: 8
+   mlp_layers: 1
+   dense_act: torch.nn.SiLU
+   cnn_act: torch.nn.SiLU
+   unimix: 0.01
+   hafner_initialization: true
+   player:
+     discrete_size: 4
+ buffer:
+   size: 1024
+   memmap: true
+   validate_args: false
+   from_numpy: false
+   checkpoint: true
+ checkpoint:
+   every: 10000
+   resume_from: null
+   save_last: true
+   keep_last: 5
+ distribution:
+   validate_args: false
+   type: auto
+ env:
+   id: doapp
+   num_envs: 1
+   frame_stack: -1
+   sync_env: true
+   screen_size: 64
+   action_repeat: 1
+   grayscale: false
+   clip_rewards: false
+   capture_video: true
+   frame_stack_dilation: 1
+   max_episode_steps: null
+   reward_as_observation: false
+   wrapper:
+     _target_: sheeprl.envs.diambra.DiambraWrapper
+     id: doapp
+     action_space: DISCRETE
+     screen_size: 64
+     grayscale: false
+     repeat_action: 1
+     rank: null
+     log_level: 0
+     increase_performance: true
+     diambra_settings:
+       role: P1
+       step_ratio: 6
+       difficulty: 4
+       continue_game: 0.0
+       show_final: false
+       outfits: 2
+       splash_screen: false
+     diambra_wrappers:
+       stack_actions: 1
+       no_op_max: 0
+       no_attack_buttons_combinations: false
+       add_last_action: true
+       scale: false
+       exclude_image_scaling: false
+       process_discrete_binary: false
+       role_relative: true
+ fabric:
+   _target_: lightning.fabric.Fabric
+   devices: 1
+   num_nodes: 1
+   strategy: auto
+   accelerator: cpu
+   precision: 32-true
+   callbacks:
+   - _target_: sheeprl.utils.callback.CheckpointCallback
+     keep_last: 5
+ metric:
+   log_every: 5000
+   disable_timer: false
+   log_level: 1
+   sync_on_compute: false
+   aggregator:
+     _target_: sheeprl.utils.metric.MetricAggregator
+     raise_on_missing: false
+     metrics:
+       Rewards/rew_avg:
+         _target_: torchmetrics.MeanMetric
+         sync_on_compute: false
+       Game/ep_len_avg:
+         _target_: torchmetrics.MeanMetric
+         sync_on_compute: false
+       Loss/world_model_loss:
+         _target_: torchmetrics.MeanMetric
+         sync_on_compute: false
+       Loss/value_loss:
+         _target_: torchmetrics.MeanMetric
+         sync_on_compute: false
+       Loss/policy_loss:
+         _target_: torchmetrics.MeanMetric
+         sync_on_compute: false
+       Loss/observation_loss:
+         _target_: torchmetrics.MeanMetric
+         sync_on_compute: false
+       Loss/reward_loss:
+         _target_: torchmetrics.MeanMetric
+         sync_on_compute: false
+       Loss/state_loss:
+         _target_: torchmetrics.MeanMetric
+         sync_on_compute: false
+       Loss/continue_loss:
+         _target_: torchmetrics.MeanMetric
+         sync_on_compute: false
+       State/kl:
+         _target_: torchmetrics.MeanMetric
+         sync_on_compute: false
+       State/post_entropy:
+         _target_: torchmetrics.MeanMetric
+         sync_on_compute: false
+       State/prior_entropy:
+         _target_: torchmetrics.MeanMetric
+         sync_on_compute: false
+       Grads/world_model:
+         _target_: torchmetrics.MeanMetric
+         sync_on_compute: false
+       Grads/actor:
+         _target_: torchmetrics.MeanMetric
+         sync_on_compute: false
+       Grads/critic:
+         _target_: torchmetrics.MeanMetric
+         sync_on_compute: false
+   logger:
+     _target_: lightning.fabric.loggers.TensorBoardLogger
+     name: 2024-04-16_17-34-17_dreamer_v3_doapp_42
+     root_dir: logs/runs/dreamer_v3/doapp
+     version: null
+     default_hp_metric: true
+     prefix: ''
+     sub_dir: null
+ model_manager:
+   disabled: true
+   models:
+     world_model:
+       model_name: dreamer_v3_doapp_world_model
+       description: DreamerV3 World Model used in doapp Environment
+       tags: {}
+     actor:
+       model_name: dreamer_v3_doapp_actor
+       description: DreamerV3 Actor used in doapp Environment
+       tags: {}
+     critic:
+       model_name: dreamer_v3_doapp_critic
+       description: DreamerV3 Critic used in doapp Environment
+       tags: {}
+     target_critic:
+       model_name: dreamer_v3_doapp_target_critic
+       description: DreamerV3 Target Critic used in doapp Environment
+       tags: {}
+     moments:
+       model_name: dreamer_v3_doapp_moments
+       description: DreamerV3 Moments used in doapp Environment
+       tags: {}
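This config is what the agent script consumes: `main()` loads it with OmegaConf, resolves interpolations, and wraps the result in `dotdict` so nested keys can be read with attribute access. A minimal sketch of that round trip, with paths assuming the layout added in this commit:

    from omegaconf import OmegaConf
    from sheeprl.utils.utils import dotdict

    # Load the YAML above and resolve it into a plain dict with dot access
    cfg = dotdict(
        OmegaConf.to_container(
            OmegaConf.load("results/dreamer_v3/config.yaml"), resolve=True
        )
    )
    print(cfg.algo.name)              # dreamer_v3
    print(cfg.algo.cnn_keys.encoder)  # ['frame']
    print(cfg.env.wrapper.id)         # doapp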
ckpt_1024_0.ckpt → results/ppo/ckpt_1024_0.ckpt RENAMED
File without changes
config.yaml → results/ppo/config.yaml RENAMED
File without changes