Michele Milesi committed
Commit 6b39341 · 1 Parent(s): 12794d0
feat: added dv3
agent-dreamer_v3.py
ADDED
@@ -0,0 +1,126 @@
import argparse
import json

import gymnasium as gym
import torch
from lightning import Fabric
from omegaconf import OmegaConf
from sheeprl.algos.dreamer_v3.agent import build_agent
from sheeprl.utils.env import make_env
from sheeprl.utils.utils import dotdict

"""This is an example agent based on SheepRL.

Usage:
cd sheeprl
diambra run python agent-dreamer_v3.py --cfg_path "./fake-logs/runs/dreamer_v3/doapp/fake-experiment/version_0/config.yaml" --checkpoint_path "./fake-logs/runs/dreamer_v3/doapp/fake-experiment/version_0/checkpoint/ckpt_1024_0.ckpt"
"""


def main(cfg_path: str, checkpoint_path: str, test=False):
    # Read the cfg file
    cfg = dotdict(OmegaConf.to_container(OmegaConf.load(cfg_path), resolve=True))
    print("Config parameters = ", json.dumps(cfg, sort_keys=True, indent=4))

    # Override configs for evaluation
    # You do not need to capture the video since you are submitting the agent and the video is recorded by DIAMBRA
    cfg.env.capture_video = False

    # Instantiate Fabric
    # You must use the same precision and plugins used for training.
    precision = getattr(cfg.fabric, "precision", None)
    plugins = getattr(cfg.fabric, "plugins", None)
    fabric = Fabric(
        accelerator="auto",
        devices=1,
        num_nodes=1,
        precision=precision,
        plugins=plugins,
        strategy="auto",
    )

    # Create Environment
    env = make_env(cfg, 0, 0)()
    observation_space = env.observation_space
    is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete)
    actions_dim = tuple(
        env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]
    )
    cnn_keys = cfg.algo.cnn_keys.encoder
    mlp_keys = cfg.algo.mlp_keys.encoder
    obs_keys = mlp_keys + cnn_keys

    # Load the trained agent
    state = fabric.load(checkpoint_path)
    # You need to retrieve only the player
    # Check for each algorithm which models the `build_agent()` function returns
    # (placed in the `agent.py` file of the algorithm) and which arguments it needs.
    # Check also the keys of the checkpoint: if the `build_agent()` parameter
    # is called `model_state`, then you retrieve the model state with `state["model"]`.
    agent = build_agent(
        fabric=fabric,
        actions_dim=actions_dim,
        is_continuous=False,
        cfg=cfg,
        obs_space=observation_space,
        world_model_state=state["world_model"],
        actor_state=state["actor"],
        critic_state=state["critic"],
        target_critic_state=state["target_critic"],
    )[-1]
    agent.eval()

    # Print policy network architecture
    print("Policy architecture:")
    print(agent)

    o, info = env.reset()

    while True:
        # Convert numpy observations into torch observations and normalize image observations.
        # Every algorithm has its own way to do it: check the `test()` function called
        # in the `evaluate.py` file of the algorithm for the correct way to do it.
        obs = {}
        for k in obs_keys:
            obs[k] = (
                torch.from_numpy(o[k]).to(fabric.device).view(1, 1, *o[k].shape).float()
            )
            if k in cnn_keys:
                obs[k] = obs[k] / 255 - 0.5

        # Select actions: the agent returns one one-hot categorical distribution, or
        # multiple one-hot categorical distributions for multi-discrete action spaces
        actions = agent.get_actions(obs, greedy=False)
        # Convert actions from one-hot categorical to categorical
        actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)

        o, _, terminated, truncated, info = env.step(
            actions.cpu().numpy().reshape(env.action_space.shape)
        )

        if terminated or truncated:
            o, info = env.reset()
            if info["env_done"] or test is True:
                break

    # Close the environment
    env.close()

    # Return success
    return 0


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cfg_path", type=str, required=True, help="Configuration file"
    )
    parser.add_argument(
        "--checkpoint_path", type=str, default="model", help="Model checkpoint"
    )
    parser.add_argument("--test", action="store_true", help="Test mode")
    opt = parser.parse_args()
    print(opt)

    main(opt.cfg_path, opt.checkpoint_path, opt.test)
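As the comments in the script note, the checkpoint keys must match the parameters expected by `build_agent()`. A minimal sketch for checking them before wiring up the call, assuming `fabric` and `checkpoint_path` are set up as in the script above:

# Minimal sketch: list the top-level keys stored in the checkpoint so they can be
# matched against the `build_agent()` parameters (e.g. world_model, actor, critic, target_critic).
state = fabric.load(checkpoint_path)
print(sorted(state.keys()))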
results/dreamer_v3/ckpt_1024_0.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0f9a9b2bccb05f94a374a010446b009febd8ce8ae63105aec1455cf99c5b4cdc
size 389012
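This checkpoint is stored as a Git LFS pointer: the committed file only records the object hash and size (here about 389 KB), while the actual weights are fetched with `git lfs pull` when the repository is cloned with Git LFS installed.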
results/dreamer_v3/config.yaml
ADDED
@@ -0,0 +1,354 @@
num_threads: 1
float32_matmul_precision: high
dry_run: false
seed: 42
torch_use_deterministic_algorithms: false
torch_backends_cudnn_benchmark: true
torch_backends_cudnn_deterministic: false
cublas_workspace_config: null
exp_name: dreamer_v3_doapp
run_name: 2024-04-16_17-34-17_dreamer_v3_doapp_42
root_dir: dreamer_v3/doapp
algo:
  name: dreamer_v3
  total_steps: 1024
  per_rank_batch_size: 2
  run_test: false
  cnn_keys:
    encoder:
    - frame
    decoder:
    - frame
  mlp_keys:
    encoder:
    - own_character
    - own_health
    - own_side
    - own_wins
    - opp_character
    - opp_health
    - opp_side
    - opp_wins
    - stage
    - timer
    - action
    decoder:
    - own_character
    - own_health
    - own_side
    - own_wins
    - opp_character
    - opp_health
    - opp_side
    - opp_wins
    - stage
    - timer
    - action
  world_model:
    optimizer:
      _target_: torch.optim.Adam
      lr: 0.0001
      eps: 1.0e-08
      weight_decay: 0
      betas:
      - 0.9
      - 0.999
    discrete_size: 4
    stochastic_size: 4
    kl_dynamic: 0.5
    kl_representation: 0.1
    kl_free_nats: 1.0
    kl_regularizer: 1.0
    continue_scale_factor: 1.0
    clip_gradients: 1000.0
    decoupled_rssm: false
    learnable_initial_recurrent_state: true
    encoder:
      cnn_channels_multiplier: 2
      cnn_act: torch.nn.SiLU
      dense_act: torch.nn.SiLU
      mlp_layers: 1
      cnn_layer_norm:
        cls: sheeprl.models.models.LayerNormChannelLast
        kw:
          eps: 0.001
      mlp_layer_norm:
        cls: sheeprl.models.models.LayerNorm
        kw:
          eps: 0.001
      dense_units: 8
    recurrent_model:
      recurrent_state_size: 8
      layer_norm:
        cls: sheeprl.models.models.LayerNorm
        kw:
          eps: 0.001
      dense_units: 8
    transition_model:
      hidden_size: 8
      dense_act: torch.nn.SiLU
      layer_norm:
        cls: sheeprl.models.models.LayerNorm
        kw:
          eps: 0.001
    representation_model:
      hidden_size: 8
      dense_act: torch.nn.SiLU
      layer_norm:
        cls: sheeprl.models.models.LayerNorm
        kw:
          eps: 0.001
    observation_model:
      cnn_channels_multiplier: 2
      cnn_act: torch.nn.SiLU
      dense_act: torch.nn.SiLU
      mlp_layers: 1
      cnn_layer_norm:
        cls: sheeprl.models.models.LayerNormChannelLast
        kw:
          eps: 0.001
      mlp_layer_norm:
        cls: sheeprl.models.models.LayerNorm
        kw:
          eps: 0.001
      dense_units: 8
    reward_model:
      dense_act: torch.nn.SiLU
      mlp_layers: 1
      layer_norm:
        cls: sheeprl.models.models.LayerNorm
        kw:
          eps: 0.001
      dense_units: 8
      bins: 255
    discount_model:
      learnable: true
      dense_act: torch.nn.SiLU
      mlp_layers: 1
      layer_norm:
        cls: sheeprl.models.models.LayerNorm
        kw:
          eps: 0.001
      dense_units: 8
  actor:
    optimizer:
      _target_: torch.optim.Adam
      lr: 8.0e-05
      eps: 1.0e-05
      weight_decay: 0
      betas:
      - 0.9
      - 0.999
    cls: sheeprl.algos.dreamer_v3.agent.Actor
    ent_coef: 0.0003
    min_std: 0.1
    max_std: 1.0
    init_std: 2.0
    dense_act: torch.nn.SiLU
    mlp_layers: 1
    layer_norm:
      cls: sheeprl.models.models.LayerNorm
      kw:
        eps: 0.001
    dense_units: 8
    clip_gradients: 100.0
    unimix: 0.01
    action_clip: 1.0
    moments:
      decay: 0.99
      max: 1.0
      percentile:
        low: 0.05
        high: 0.95
  critic:
    optimizer:
      _target_: torch.optim.Adam
      lr: 8.0e-05
      eps: 1.0e-05
      weight_decay: 0
      betas:
      - 0.9
      - 0.999
    dense_act: torch.nn.SiLU
    mlp_layers: 1
    layer_norm:
      cls: sheeprl.models.models.LayerNorm
      kw:
        eps: 0.001
    dense_units: 8
    per_rank_target_network_update_freq: 1
    tau: 0.02
    bins: 255
    clip_gradients: 100.0
  gamma: 0.996996996996997
  lmbda: 0.95
  horizon: 15
  replay_ratio: 0.0625
  learning_starts: 1024
  per_rank_pretrain_steps: 0
  per_rank_sequence_length: 64
  cnn_layer_norm:
    cls: sheeprl.models.models.LayerNormChannelLast
    kw:
      eps: 0.001
  mlp_layer_norm:
    cls: sheeprl.models.models.LayerNorm
    kw:
      eps: 0.001
  dense_units: 8
  mlp_layers: 1
  dense_act: torch.nn.SiLU
  cnn_act: torch.nn.SiLU
  unimix: 0.01
  hafner_initialization: true
  player:
    discrete_size: 4
buffer:
  size: 1024
  memmap: true
  validate_args: false
  from_numpy: false
  checkpoint: true
checkpoint:
  every: 10000
  resume_from: null
  save_last: true
  keep_last: 5
distribution:
  validate_args: false
  type: auto
env:
  id: doapp
  num_envs: 1
  frame_stack: -1
  sync_env: true
  screen_size: 64
  action_repeat: 1
  grayscale: false
  clip_rewards: false
  capture_video: true
  frame_stack_dilation: 1
  max_episode_steps: null
  reward_as_observation: false
  wrapper:
    _target_: sheeprl.envs.diambra.DiambraWrapper
    id: doapp
    action_space: DISCRETE
    screen_size: 64
    grayscale: false
    repeat_action: 1
    rank: null
    log_level: 0
    increase_performance: true
    diambra_settings:
      role: P1
      step_ratio: 6
      difficulty: 4
      continue_game: 0.0
      show_final: false
      outfits: 2
      splash_screen: false
    diambra_wrappers:
      stack_actions: 1
      no_op_max: 0
      no_attack_buttons_combinations: false
      add_last_action: true
      scale: false
      exclude_image_scaling: false
      process_discrete_binary: false
      role_relative: true
fabric:
  _target_: lightning.fabric.Fabric
  devices: 1
  num_nodes: 1
  strategy: auto
  accelerator: cpu
  precision: 32-true
  callbacks:
  - _target_: sheeprl.utils.callback.CheckpointCallback
    keep_last: 5
metric:
  log_every: 5000
  disable_timer: false
  log_level: 1
  sync_on_compute: false
  aggregator:
    _target_: sheeprl.utils.metric.MetricAggregator
    raise_on_missing: false
    metrics:
      Rewards/rew_avg:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: false
      Game/ep_len_avg:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: false
      Loss/world_model_loss:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: false
      Loss/value_loss:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: false
      Loss/policy_loss:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: false
      Loss/observation_loss:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: false
      Loss/reward_loss:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: false
      Loss/state_loss:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: false
      Loss/continue_loss:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: false
      State/kl:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: false
      State/post_entropy:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: false
      State/prior_entropy:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: false
      Grads/world_model:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: false
      Grads/actor:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: false
      Grads/critic:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: false
  logger:
    _target_: lightning.fabric.loggers.TensorBoardLogger
    name: 2024-04-16_17-34-17_dreamer_v3_doapp_42
    root_dir: logs/runs/dreamer_v3/doapp
    version: null
    default_hp_metric: true
    prefix: ''
    sub_dir: null
model_manager:
  disabled: true
  models:
    world_model:
      model_name: dreamer_v3_doapp_world_model
      description: DreamerV3 World Model used in doapp Environment
      tags: {}
    actor:
      model_name: dreamer_v3_doapp_actor
      description: DreamerV3 Actor used in doapp Environment
      tags: {}
    critic:
      model_name: dreamer_v3_doapp_critic
      description: DreamerV3 Critic used in doapp Environment
      tags: {}
    target_critic:
      model_name: dreamer_v3_doapp_target_critic
      description: DreamerV3 Target Critic used in doapp Environment
      tags: {}
    moments:
      model_name: dreamer_v3_doapp_moments
      description: DreamerV3 Moments used in doapp Environment
      tags: {}
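Given the repository layout above, the evaluation script can be pointed directly at these files; the paths below refer to this repository's `results/dreamer_v3` folder rather than the `fake-logs` placeholders used in the script's docstring:

cd sheeprl
diambra run python agent-dreamer_v3.py --cfg_path "./results/dreamer_v3/config.yaml" --checkpoint_path "./results/dreamer_v3/ckpt_1024_0.ckpt"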
ckpt_1024_0.ckpt → results/ppo/ckpt_1024_0.ckpt
RENAMED
File without changes
config.yaml → results/ppo/config.yaml
RENAMED
File without changes