michele-milesi committed on
Commit
2cbbf46
1 Parent(s): 1dd30c9

Initial Commit

Files changed (3)
  1. agent.py +116 -0
  2. ckpt_1024_0.ckpt +3 -0
  3. config.yaml +164 -0
agent.py ADDED
@@ -0,0 +1,116 @@
+ import argparse
+ import json
+
+ import gymnasium as gym
+ import torch
+ from lightning import Fabric
+ from omegaconf import OmegaConf
+ from sheeprl.algos.ppo.agent import build_agent
+ from sheeprl.utils.env import make_env
+ from sheeprl.utils.utils import dotdict
+
+ """This is an example agent based on SheepRL.
+
+ Usage:
+ diambra run python sheeprl/agent.py --cfg_path "./fake-logs/runs/ppo/doapp/fake-experiment/version_0/config.yaml" --checkpoint_path "./fake-logs/runs/ppo/doapp/fake-experiment/version_0/checkpoint/ckpt_1024_0.ckpt"
+ """
+
+
+ def main(cfg_path: str, checkpoint_path: str, test=False):
+     # Read the cfg file
+     cfg = dotdict(OmegaConf.to_container(OmegaConf.load(cfg_path), resolve=True))
+     print("Config parameters = ", json.dumps(cfg, sort_keys=True, indent=4))
+
+     # Override configs for evaluation
+     if not test:
+         cfg.env.capture_video = False
+         cfg.env.num_envs = 1
+
+     # Instantiate Fabric
+     precision = getattr(cfg.fabric, "precision", None)
+     plugins = getattr(cfg.fabric, "plugins", None)
+     fabric = Fabric(
+         accelerator="cpu",
+         devices=1,
+         num_nodes=1,
+         precision=precision,
+         plugins=plugins,
+         strategy="auto",
+     )
+
+     # Create Environment
+     env = make_env(cfg, 0, 0)()
+     observation_space = env.observation_space
+     is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete)
+     actions_dim = tuple(
+         env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]
+     )
+     cnn_keys = cfg.algo.cnn_keys.encoder
+     mlp_keys = cfg.algo.mlp_keys.encoder
+     obs_keys = mlp_keys + cnn_keys
+
+     # Load the trained agent
+     state = fabric.load(checkpoint_path)
+     # You need to retrieve only the player
+     agent = build_agent(
+         fabric=fabric,
+         actions_dim=actions_dim,
+         is_continuous=False,
+         cfg=cfg,
+         obs_space=observation_space,
+         agent_state=state["agent"],
+     )[-1]
+     agent.eval()
+
+     # Print policy network architecture
+     print("Policy architecture:")
+     print(agent)
+
+     o, info = env.reset()
+
+     while True:
+         # Convert numpy observations into torch observations and normalize image observations
+         obs = {}
+         for k in o.keys():
+             if k in obs_keys:
+                 torch_obs = torch.from_numpy(o[k].copy()).to(fabric.device).unsqueeze(0)
+                 if k in cnn_keys:
+                     torch_obs = (
+                         torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5
+                     )
+                 if k in mlp_keys:
+                     torch_obs = torch_obs.float()
+                 obs[k] = torch_obs
+
+         actions = agent.get_actions(obs, greedy=True)
+         actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)
+
+         o, _, terminated, truncated, info = env.step(
+             actions.cpu().numpy().reshape(env.action_space.shape)
+         )
+
+         if terminated or truncated:
+             o, info = env.reset()
+             if info["env_done"] or test is True:
+                 break
+
+     # Close the environment
+     env.close()
+
+     # Return success
+     return 0
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--cfg_path", type=str, required=True, help="Configuration file"
+     )
+     parser.add_argument(
+         "--checkpoint_path", type=str, default="model", help="Model checkpoint"
+     )
+     parser.add_argument("--test", action="store_true", help="Test mode")
+     opt = parser.parse_args()
+     print(opt)
+
+     main(opt.cfg_path, opt.checkpoint_path, opt.test)
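
For a quick sanity check before launching the full evaluation loop, the checkpoint can be loaded on its own and its contents listed. A minimal sketch, assuming the same checkpoint layout that agent.py relies on (a dict holding the model weights under an "agent" entry):

# Minimal sketch: inspect ckpt_1024_0.ckpt before running the agent.
# Assumes the layout agent.py expects: a dict with the weights under "agent".
from lightning import Fabric

fabric = Fabric(accelerator="cpu", devices=1)
state = fabric.load("ckpt_1024_0.ckpt")

print(list(state.keys()))  # "agent" should appear among the stored entries
for name, tensor in state["agent"].items():
    print(name, tuple(tensor.shape))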
ckpt_1024_0.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7a7b985f51b9f2f40182083b57bc785a9972c72227895a14f1d4c764bfa4b8f0
+ size 2582118
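
Note that ckpt_1024_0.ckpt is committed as a Git LFS pointer: the three lines above are the entire file as stored in Git, and the actual weights (about 2.5 MB, per the size field) are fetched by git lfs pull after cloning. A small hypothetical helper, not part of the repo, to catch a missing pull before Fabric fails on the pointer file:

# Hypothetical helper: detect a Git LFS pointer that was never replaced by
# the real weights. LFS pointers start with "version https://git-lfs...",
# while a real PyTorch checkpoint is a zip archive starting with b"PK".
def is_lfs_pointer(path: str) -> bool:
    with open(path, "rb") as f:
        return f.read(7) == b"version"

if is_lfs_pointer("ckpt_1024_0.ckpt"):
    raise RuntimeError("Run `git lfs pull` to download the real checkpoint")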
config.yaml ADDED
@@ -0,0 +1,164 @@
+ num_threads: 1
+ float32_matmul_precision: high
+ dry_run: false
+ seed: 42
+ torch_use_deterministic_algorithms: false
+ torch_backends_cudnn_benchmark: true
+ torch_backends_cudnn_deterministic: false
+ cublas_workspace_config: null
+ exp_name: ppo_doapp
+ run_name: 2024-04-15_15-25-55_ppo_doapp_42
+ root_dir: ppo/doapp
+ algo:
+   name: ppo
+   total_steps: 1024
+   per_rank_batch_size: 16
+   run_test: true
+   cnn_keys:
+     encoder:
+     - frame
+   mlp_keys:
+     encoder:
+     - own_character
+     - own_health
+     - own_side
+     - own_wins
+     - opp_character
+     - opp_health
+     - opp_side
+     - opp_wins
+     - stage
+     - timer
+     - action
+   optimizer:
+     _target_: torch.optim.Adam
+     lr: 0.005
+     eps: 1.0e-06
+     weight_decay: 0
+     betas:
+     - 0.9
+     - 0.999
+   anneal_lr: false
+   gamma: 0.99
+   gae_lambda: 0.95
+   update_epochs: 1
+   loss_reduction: mean
+   normalize_advantages: true
+   clip_coef: 0.2
+   anneal_clip_coef: false
+   clip_vloss: false
+   ent_coef: 0.0
+   anneal_ent_coef: false
+   vf_coef: 1.0
+   rollout_steps: 32
+   dense_units: 16
+   mlp_layers: 1
+   dense_act: torch.nn.Tanh
+   layer_norm: false
+   max_grad_norm: 1.0
+   encoder:
+     cnn_features_dim: 128
+     mlp_features_dim: 32
+     dense_units: 16
+     mlp_layers: 1
+     dense_act: torch.nn.Tanh
+     layer_norm: false
+   actor:
+     dense_units: 16
+     mlp_layers: 1
+     dense_act: torch.nn.Tanh
+     layer_norm: false
+   critic:
+     dense_units: 16
+     mlp_layers: 1
+     dense_act: torch.nn.Tanh
+     layer_norm: false
+ buffer:
+   size: 32
+   memmap: true
+   validate_args: false
+   from_numpy: false
+   share_data: false
+ checkpoint:
+   every: 100
+   resume_from: null
+   save_last: true
+   keep_last: 5
+ distribution:
+   validate_args: false
+ env:
+   id: doapp
+   num_envs: 1
+   frame_stack: 1
+   sync_env: true
+   screen_size: 64
+   action_repeat: 1
+   grayscale: false
+   clip_rewards: false
+   capture_video: true
+   frame_stack_dilation: 1
+   max_episode_steps: null
+   reward_as_observation: false
+   wrapper:
+     _target_: sheeprl.envs.diambra.DiambraWrapper
+     id: doapp
+     action_space: DISCRETE
+     screen_size: 64
+     grayscale: false
+     repeat_action: 1
+     rank: null
+     log_level: 0
+     increase_performance: true
+     diambra_settings:
+       role: P1
+       step_ratio: 6
+       difficulty: 4
+       continue_game: 0.0
+       show_final: false
+       outfits: 2
+       splash_screen: false
+     diambra_wrappers:
+       stack_actions: 1
+       no_op_max: 0
+       no_attack_buttons_combinations: false
+       add_last_action: true
+       scale: false
+       exclude_image_scaling: false
+       process_discrete_binary: false
+       role_relative: true
+ fabric:
+   _target_: lightning.fabric.Fabric
+   devices: 1
+   num_nodes: 1
+   strategy: auto
+   accelerator: cpu
+   precision: 32-true
+   callbacks:
+   - _target_: sheeprl.utils.callback.CheckpointCallback
+     keep_last: 5
+ metric:
+   log_every: 5000
+   disable_timer: false
+   log_level: 1
+   sync_on_compute: false
+   aggregator:
+     _target_: sheeprl.utils.metric.MetricAggregator
+     raise_on_missing: false
+     metrics:
+       Rewards/rew_avg:
+         _target_: torchmetrics.MeanMetric
+         sync_on_compute: false
+       Game/ep_len_avg:
+         _target_: torchmetrics.MeanMetric
+         sync_on_compute: false
+   logger:
+     _target_: lightning.fabric.loggers.TensorBoardLogger
+     name: 2024-04-15_15-25-55_ppo_doapp_42
+     root_dir: logs/runs/ppo/doapp
+     version: null
+     default_hp_metric: true
+     prefix: ''
+     sub_dir: null
+ model_manager:
+   disabled: true
+   models: {}
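
agent.py consumes this file through OmegaConf and flattens it into a dotdict, so every key above is reachable via attribute access. A minimal sketch of that round trip, assuming config.yaml sits in the working directory:

# Minimal sketch: load config.yaml the same way agent.py does and read a few fields.
from omegaconf import OmegaConf
from sheeprl.utils.utils import dotdict

cfg = dotdict(OmegaConf.to_container(OmegaConf.load("config.yaml"), resolve=True))
print(cfg.algo.name)              # ppo
print(cfg.algo.cnn_keys.encoder)  # ['frame']
print(cfg.env.wrapper.id)         # doapp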