michele-milesi committed 2cbbf46
Parent(s): 1dd30c9
Initial Commit
- agent.py +116 -0
- ckpt_1024_0.ckpt +3 -0
- config.yaml +164 -0
agent.py
ADDED
@@ -0,0 +1,116 @@
import argparse
import json

import gymnasium as gym
import torch
from lightning import Fabric
from omegaconf import OmegaConf
from sheeprl.algos.ppo.agent import build_agent
from sheeprl.utils.env import make_env
from sheeprl.utils.utils import dotdict

"""This is an example agent based on SheepRL.

Usage:
diambra run python sheeprl/agent.py --cfg_path "./fake-logs/runs/ppo/doapp/fake-experiment/version_0/config.yaml" --checkpoint_path "./fake-logs/runs/ppo/doapp/fake-experiment/version_0/checkpoint/ckpt_1024_0.ckpt"
"""


def main(cfg_path: str, checkpoint_path: str, test=False):
    # Read the cfg file
    cfg = dotdict(OmegaConf.to_container(OmegaConf.load(cfg_path), resolve=True))
    print("Config parameters = ", json.dumps(cfg, sort_keys=True, indent=4))

    # Override configs for evaluation
    if not test:
        cfg.env.capture_video = False
    cfg.env.num_envs = 1

    # Instantiate Fabric
    precision = getattr(cfg.fabric, "precision", None)
    plugins = getattr(cfg.fabric, "plugins", None)
    fabric = Fabric(
        accelerator="cpu",
        devices=1,
        num_nodes=1,
        precision=precision,
        plugins=plugins,
        strategy="auto",
    )

    # Create Environment
    env = make_env(cfg, 0, 0)()
    observation_space = env.observation_space
    is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete)
    actions_dim = tuple(
        env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]
    )
    cnn_keys = cfg.algo.cnn_keys.encoder
    mlp_keys = cfg.algo.mlp_keys.encoder
    obs_keys = mlp_keys + cnn_keys

    # Load the trained agent
    state = fabric.load(checkpoint_path)
    # You need to retrieve only the player
    agent = build_agent(
        fabric=fabric,
        actions_dim=actions_dim,
        is_continuous=False,
        cfg=cfg,
        obs_space=observation_space,
        agent_state=state["agent"],
    )[-1]
    agent.eval()

    # Print policy network architecture
    print("Policy architecture:")
    print(agent)

    o, info = env.reset()

    while True:
        # Convert numpy observations into torch observations and normalize image observations
        obs = {}
        for k in o.keys():
            if k in obs_keys:
                torch_obs = torch.from_numpy(o[k].copy()).to(fabric.device).unsqueeze(0)
                if k in cnn_keys:
                    torch_obs = (
                        torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5
                    )
                if k in mlp_keys:
                    torch_obs = torch_obs.float()
                obs[k] = torch_obs

        actions = agent.get_actions(obs, greedy=True)
        actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)

        o, _, terminated, truncated, info = env.step(
            actions.cpu().numpy().reshape(env.action_space.shape)
        )

        if terminated or truncated:
            o, info = env.reset()
            if info["env_done"] or test is True:
                break

    # Close the environment
    env.close()

    # Return success
    return 0


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cfg_path", type=str, required=True, help="Configuration file"
    )
    parser.add_argument(
        "--checkpoint_path", type=str, default="model", help="Model checkpoint"
    )
    parser.add_argument("--test", action="store_true", help="Test mode")
    opt = parser.parse_args()
    print(opt)

    main(opt.cfg_path, opt.checkpoint_path, opt.test)
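
A quick way to exercise the loop above end-to-end is the script's own test mode, which plays a single episode and exits. This is a minimal sketch, assuming the config and checkpoint committed here sit in the working directory and a DIAMBRA engine is reachable (e.g. the script is launched via diambra run); both paths are assumptions based on this commit's layout:

    # Hypothetical smoke test: one episode in test mode with the files from this commit.
    from agent import main

    main(cfg_path="./config.yaml", checkpoint_path="./ckpt_1024_0.ckpt", test=True)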
ckpt_1024_0.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7a7b985f51b9f2f40182083b57bc785a9972c72227895a14f1d4c764bfa4b8f0
size 2582118
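
Note that the three lines above are a Git LFS pointer, not the checkpoint itself; the real ~2.5 MB file is fetched with git lfs pull. A minimal stdlib-only sketch for checking that a fetched checkpoint matches the pointer's oid and size (the local path is an assumption based on this commit's layout):

    # Verify a fetched checkpoint against the LFS pointer metadata above.
    import hashlib
    import os

    EXPECTED_OID = "7a7b985f51b9f2f40182083b57bc785a9972c72227895a14f1d4c764bfa4b8f0"
    EXPECTED_SIZE = 2582118

    path = "ckpt_1024_0.ckpt"  # assumed location; adjust if yours differs
    # A size mismatch usually means the file is still the pointer, not the blob.
    assert os.path.getsize(path) == EXPECTED_SIZE, "size mismatch: run `git lfs pull`"
    with open(path, "rb") as f:
        assert hashlib.sha256(f.read()).hexdigest() == EXPECTED_OID, "sha256 mismatch"
    print("checkpoint matches the LFS pointer")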
config.yaml
ADDED
@@ -0,0 +1,164 @@
num_threads: 1
float32_matmul_precision: high
dry_run: false
seed: 42
torch_use_deterministic_algorithms: false
torch_backends_cudnn_benchmark: true
torch_backends_cudnn_deterministic: false
cublas_workspace_config: null
exp_name: ppo_doapp
run_name: 2024-04-15_15-25-55_ppo_doapp_42
root_dir: ppo/doapp
algo:
  name: ppo
  total_steps: 1024
  per_rank_batch_size: 16
  run_test: true
  cnn_keys:
    encoder:
    - frame
  mlp_keys:
    encoder:
    - own_character
    - own_health
    - own_side
    - own_wins
    - opp_character
    - opp_health
    - opp_side
    - opp_wins
    - stage
    - timer
    - action
  optimizer:
    _target_: torch.optim.Adam
    lr: 0.005
    eps: 1.0e-06
    weight_decay: 0
    betas:
    - 0.9
    - 0.999
  anneal_lr: false
  gamma: 0.99
  gae_lambda: 0.95
  update_epochs: 1
  loss_reduction: mean
  normalize_advantages: true
  clip_coef: 0.2
  anneal_clip_coef: false
  clip_vloss: false
  ent_coef: 0.0
  anneal_ent_coef: false
  vf_coef: 1.0
  rollout_steps: 32
  dense_units: 16
  mlp_layers: 1
  dense_act: torch.nn.Tanh
  layer_norm: false
  max_grad_norm: 1.0
  encoder:
    cnn_features_dim: 128
    mlp_features_dim: 32
    dense_units: 16
    mlp_layers: 1
    dense_act: torch.nn.Tanh
    layer_norm: false
  actor:
    dense_units: 16
    mlp_layers: 1
    dense_act: torch.nn.Tanh
    layer_norm: false
  critic:
    dense_units: 16
    mlp_layers: 1
    dense_act: torch.nn.Tanh
    layer_norm: false
buffer:
  size: 32
  memmap: true
  validate_args: false
  from_numpy: false
  share_data: false
checkpoint:
  every: 100
  resume_from: null
  save_last: true
  keep_last: 5
distribution:
  validate_args: false
env:
  id: doapp
  num_envs: 1
  frame_stack: 1
  sync_env: true
  screen_size: 64
  action_repeat: 1
  grayscale: false
  clip_rewards: false
  capture_video: true
  frame_stack_dilation: 1
  max_episode_steps: null
  reward_as_observation: false
  wrapper:
    _target_: sheeprl.envs.diambra.DiambraWrapper
    id: doapp
    action_space: DISCRETE
    screen_size: 64
    grayscale: false
    repeat_action: 1
    rank: null
    log_level: 0
    increase_performance: true
    diambra_settings:
      role: P1
      step_ratio: 6
      difficulty: 4
      continue_game: 0.0
      show_final: false
      outfits: 2
      splash_screen: false
    diambra_wrappers:
      stack_actions: 1
      no_op_max: 0
      no_attack_buttons_combinations: false
      add_last_action: true
      scale: false
      exclude_image_scaling: false
      process_discrete_binary: false
      role_relative: true
fabric:
  _target_: lightning.fabric.Fabric
  devices: 1
  num_nodes: 1
  strategy: auto
  accelerator: cpu
  precision: 32-true
  callbacks:
  - _target_: sheeprl.utils.callback.CheckpointCallback
    keep_last: 5
metric:
  log_every: 5000
  disable_timer: false
  log_level: 1
  sync_on_compute: false
  aggregator:
    _target_: sheeprl.utils.metric.MetricAggregator
    raise_on_missing: false
    metrics:
      Rewards/rew_avg:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: false
      Game/ep_len_avg:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: false
  logger:
    _target_: lightning.fabric.loggers.TensorBoardLogger
    name: 2024-04-15_15-25-55_ppo_doapp_42
    root_dir: logs/runs/ppo/doapp
    version: null
    default_hp_metric: true
    prefix: ''
    sub_dir: null
model_manager:
  disabled: true
  models: {}
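
This is exactly the file that agent.py loads. A minimal sketch of how the keys above are consumed, mirroring the calls already made in agent.py (assumes the file is named config.yaml in the working directory):

    # Load the config the same way agent.py does and inspect the encoder keys.
    from omegaconf import OmegaConf
    from sheeprl.utils.utils import dotdict

    cfg = dotdict(OmegaConf.to_container(OmegaConf.load("config.yaml"), resolve=True))
    print(cfg.algo.cnn_keys.encoder)  # ['frame']: image observations for the CNN encoder
    print(cfg.algo.mlp_keys.encoder)  # RAM-state keys fed to the MLP encoder
    print(cfg.fabric.precision)       # '32-true'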