Michele Milesi committed on
Commit
e184661
1 Parent(s): 5d9a387

feat: variable renaming

Browse files
Files changed (2) hide show
  1. agent-dreamer_v3.py +4 -4
  2. agent-ppo.py +4 -4
agent-dreamer_v3.py CHANGED
@@ -73,13 +73,13 @@ def main(cfg_path: str, checkpoint_path: str, test=False):
73
  print("Policy architecture:")
74
  print(agent)
75
 
76
- o, info = env.reset()
77
  # Every time you reset the environment, you must reset the initial states of the model
78
  agent.init_states()
79
 
80
  while True:
81
  # Convert numpy observations into torch observations and normalize image observations
82
- torch_obs = prepare_obs(fabric, o, cnn_keys=cnn_keys)
83
 
84
  # Select actions, the agent returns a one-hot categorical or
85
  # more one-hot categorical distributions for multi-discrete action spaces
@@ -87,12 +87,12 @@ def main(cfg_path: str, checkpoint_path: str, test=False):
87
  # Convert actions from one-hot categorical to categorical
88
  actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)
89
 
90
- o, _, terminated, truncated, info = env.step(
91
  actions.cpu().numpy().reshape(env.action_space.shape)
92
  )
93
 
94
  if terminated or truncated:
95
- o, info = env.reset()
96
  # Every time you reset the environment, you must reset the initial states of the model
97
  agent.init_states()
98
  if info["env_done"] or test is True:
 
73
  print("Policy architecture:")
74
  print(agent)
75
 
76
+ obs, info = env.reset()
77
  # Every time you reset the environment, you must reset the initial states of the model
78
  agent.init_states()
79
 
80
  while True:
81
  # Convert numpy observations into torch observations and normalize image observations
82
+ torch_obs = prepare_obs(fabric, obs, cnn_keys=cnn_keys)
83
 
84
  # Select actions, the agent returns a one-hot categorical or
85
  # more one-hot categorical distributions for multi-discrete action spaces
 
87
  # Convert actions from one-hot categorical to categorical
88
  actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)
89
 
90
+ obs, _, terminated, truncated, info = env.step(
91
  actions.cpu().numpy().reshape(env.action_space.shape)
92
  )
93
 
94
  if terminated or truncated:
95
+ obs, info = env.reset()
96
  # Every time you reset the environment, you must reset the initial states of the model
97
  agent.init_states()
98
  if info["env_done"] or test is True:
agent-ppo.py CHANGED
@@ -65,21 +65,21 @@ def main(cfg_path: str, checkpoint_path: str, test=False):
65
  print("Policy architecture:")
66
  print(agent)
67
 
68
- o, info = env.reset()
69
 
70
  while True:
71
  # Convert numpy observations into torch observations and normalize image observations
72
- torch_obs = prepare_obs(fabric, o, cnn_keys=cnn_keys)
73
 
74
  actions = agent.get_actions(torch_obs, greedy=True)
75
  actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)
76
 
77
- o, _, terminated, truncated, info = env.step(
78
  actions.cpu().numpy().reshape(env.action_space.shape)
79
  )
80
 
81
  if terminated or truncated:
82
- o, info = env.reset()
83
  if info["env_done"] or test is True:
84
  break
85
 
 
65
  print("Policy architecture:")
66
  print(agent)
67
 
68
+ obs, info = env.reset()
69
 
70
  while True:
71
  # Convert numpy observations into torch observations and normalize image observations
72
+ torch_obs = prepare_obs(fabric, obs, cnn_keys=cnn_keys)
73
 
74
  actions = agent.get_actions(torch_obs, greedy=True)
75
  actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)
76
 
77
+ obs, _, terminated, truncated, info = env.step(
78
  actions.cpu().numpy().reshape(env.action_space.shape)
79
  )
80
 
81
  if terminated or truncated:
82
+ obs, info = env.reset()
83
  if info["env_done"] or test is True:
84
  break
85