ledmands commited on
Commit
a30d4ce
1 Parent(s): 866f598

Added functionality to watch_agent.py to support evaluating different agents.

Browse files
Files changed (1) hide show
  1. agents/watch_agent.py +11 -6
agents/watch_agent.py CHANGED
@@ -5,8 +5,6 @@ import gymnasium as gym
5
 
6
  import argparse
7
 
8
- MODEL_NAME = "ALE-Pacman-v5"
9
- loaded_model = DQN.load(MODEL_NAME)
10
 
11
  # This script should have some options
12
  # 1. Turn off the stochasticity as determined by the ALEv5
@@ -18,15 +16,21 @@ loaded_model = DQN.load(MODEL_NAME)
18
  # DONE
19
  # 4. Print the keyword args for the environment? I think this might be helpful...
20
  # IN PROGRESS
21
- # 5.
 
22
 
23
  parser = argparse.ArgumentParser()
24
- parser.add_argument("-r", "--repeat_action_probability", help="repeat action probability", type=float, default=0.25)
25
- parser.add_argument("-f", "--frameskip", help="frameskip", type=int, default=4)
26
  parser.add_argument("-o", "--observe", help="observe agent", action="store_const", const=True)
27
  parser.add_argument("-p", "--print", help="print environment information", action="store_const", const=True)
 
 
28
  args = parser.parse_args()
29
 
 
 
 
30
  # Toggle the render mode based on the -o flag
31
  if args.observe == True:
32
  mode = "human"
@@ -44,6 +48,7 @@ if args.print == True:
44
  for item in env_info:
45
  print(item)
46
  # Evaluate the policy
47
- mean_rwd, std_rwd = evaluate_policy(loaded_model.policy, eval_env, n_eval_episodes=1)
 
48
  print("mean rwd: ", mean_rwd)
49
  print("std rwd: ", std_rwd)
 
5
 
6
  import argparse
7
 
 
 
8
 
9
  # This script should have some options
10
  # 1. Turn off the stochasticity as determined by the ALEv5
 
16
  # DONE
17
  # 4. Print the keyword args for the environment? I think this might be helpful...
18
  # IN PROGRESS
19
+ # 5. Add option flag to accept file path for model
20
+ # 6. Add option flag to accept number of episodes
21
 
22
  parser = argparse.ArgumentParser()
23
+ parser.add_argument("-r", "--repeat_action_probability", help="repeat action probability, default 0.25", type=float, default=0.25)
24
+ parser.add_argument("-f", "--frameskip", help="frameskip, default 4", type=int, default=4)
25
  parser.add_argument("-o", "--observe", help="observe agent", action="store_const", const=True)
26
  parser.add_argument("-p", "--print", help="print environment information", action="store_const", const=True)
27
+ parser.add_argument("-e", "--num_episodes", help="specify the number of episodes to evaluate, default 1", type=int, default=1)
28
+ parser.add_argument("-a", "--agent_filepath", help="file path to agent to watch, minus the .zip extension", type=str, required=True)
29
  args = parser.parse_args()
30
 
31
+ MODEL_NAME = args.agent_filepath
32
+ loaded_model = DQN.load(MODEL_NAME)
33
+
34
  # Toggle the render mode based on the -o flag
35
  if args.observe == True:
36
  mode = "human"
 
48
  for item in env_info:
49
  print(item)
50
  # Evaluate the policy
51
+ mean_rwd, std_rwd = evaluate_policy(loaded_model.policy, eval_env, n_eval_episodes=args.num_episodes)
52
+ print("eval episodes: ", args.num_episodes)
53
  print("mean rwd: ", mean_rwd)
54
  print("std rwd: ", std_rwd)