Michele Milesi committed
Commit e184661 • Parent(s): 5d9a387

feat: variable renaming

- agent-dreamer_v3.py +4 -4
- agent-ppo.py +4 -4
agent-dreamer_v3.py
CHANGED
@@ -73,13 +73,13 @@ def main(cfg_path: str, checkpoint_path: str, test=False):
     print("Policy architecture:")
     print(agent)
 
-
+    obs, info = env.reset()
     # Every time you reset the environment, you must reset the initial states of the model
     agent.init_states()
 
     while True:
         # Convert numpy observations into torch observations and normalize image observations
-        torch_obs = prepare_obs(fabric,
+        torch_obs = prepare_obs(fabric, obs, cnn_keys=cnn_keys)
 
         # Select actions, the agent returns a one-hot categorical or
         # more one-hot categorical distributions for multi-discrete action spaces
@@ -87,12 +87,12 @@ def main(cfg_path: str, checkpoint_path: str, test=False):
         # Convert actions from one-hot categorical to categorical
         actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)
 
-
+        obs, _, terminated, truncated, info = env.step(
             actions.cpu().numpy().reshape(env.action_space.shape)
         )
 
         if terminated or truncated:
-
+            obs, info = env.reset()
             # Every time you reset the environment, you must reset the initial states of the model
             agent.init_states()
             if info["env_done"] or test is True:
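
For context, the renamed obs variable threads through the script's whole evaluation loop. Below is a minimal sketch of the loop these two hunks touch, assuming the Gymnasium-style env, the Lightning fabric, the loaded agent, and the prepare_obs/cnn_keys helpers are created earlier in the script (none of that setup is shown in this diff). The action-selection call itself falls outside the hunks, so agent.get_actions(torch_obs, greedy=True) is borrowed from agent-ppo.py and is an assumption for the Dreamer-V3 script:

import torch

# Assumed to exist from earlier in the script (not shown in this diff):
#   env                   - Gymnasium-style environment
#   fabric                - lightning.fabric.Fabric used for device placement
#   agent                 - loaded Dreamer-V3 policy with init_states()
#   prepare_obs, cnn_keys - observation-preprocessing helper and its image keys
#   test                  - the main() flag that forces an early exit

obs, info = env.reset()
# Every time you reset the environment, you must reset the initial states of the model
agent.init_states()

while True:
    # Convert numpy observations into torch observations and normalize image observations
    torch_obs = prepare_obs(fabric, obs, cnn_keys=cnn_keys)

    # One one-hot categorical per action dimension (call name assumed, see above)
    actions = agent.get_actions(torch_obs, greedy=True)

    # Convert the one-hot samples to integer indices, one per action dimension
    actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)

    obs, _, terminated, truncated, info = env.step(
        actions.cpu().numpy().reshape(env.action_space.shape)
    )

    if terminated or truncated:
        obs, info = env.reset()
        agent.init_states()  # the model's states must be re-initialized after every reset
        if info["env_done"] or test is True:
            break
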
agent-ppo.py
CHANGED
@@ -65,21 +65,21 @@ def main(cfg_path: str, checkpoint_path: str, test=False):
     print("Policy architecture:")
     print(agent)
 
-
+    obs, info = env.reset()
 
     while True:
         # Convert numpy observations into torch observations and normalize image observations
-        torch_obs = prepare_obs(fabric,
+        torch_obs = prepare_obs(fabric, obs, cnn_keys=cnn_keys)
 
         actions = agent.get_actions(torch_obs, greedy=True)
         actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)
 
-
+        obs, _, terminated, truncated, info = env.step(
             actions.cpu().numpy().reshape(env.action_space.shape)
         )
 
         if terminated or truncated:
-
+            obs, info = env.reset()
             if info["env_done"] or test is True:
                 break
 
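
The one-hot-to-index conversion shared by both scripts is easy to sanity-check in isolation. A toy example with made-up shapes (a multi-discrete space whose two dimensions have 3 and 4 choices; the tensors are illustrative, not taken from the scripts):

import torch

# One batch element over two dimensions of a multi-discrete action space.
# Each list entry is a one-hot sample of shape (batch, num_choices).
actions = [
    torch.tensor([[0.0, 1.0, 0.0]]),       # dimension 1: 3 choices, index 1 chosen
    torch.tensor([[0.0, 0.0, 0.0, 1.0]]),  # dimension 2: 4 choices, index 3 chosen
]

# argmax recovers each one-hot vector's integer index; cat joins the
# per-dimension indices so they can be reshaped to env.action_space.shape
indices = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)
print(indices)  # tensor([1, 3])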