ledmands commited on
Commit
e036817
1 Parent(s): 2690fb6

plot improvement is a mess, but getting there

Browse files
Files changed (1) hide show
  1. plot_improvement.py +41 -4
plot_improvement.py CHANGED
@@ -1,5 +1,6 @@
1
  import argparse
2
  from numpy import load, ndarray
 
3
 
4
  parser = argparse.ArgumentParser()
5
  parser.add_argument("-f", "--filepath", required=True, help="Specify the file path to the agent.", type=str)
@@ -8,10 +9,7 @@ args = parser.parse_args()
8
  filepath = args.filepath
9
  npdata = load(filepath)
10
 
11
- print(type(npdata['results']))
12
  evaluations = ndarray.tolist(npdata['results'])
13
- print(type(evaluations))
14
- print(len(evaluations))
15
  # print(evaluations)
16
  sorted_evals = []
17
  for eval in evaluations:
@@ -39,4 +37,43 @@ print("num evals: " + str(len(mean_eval_rewards)))
39
  # The number of elements is going to vary for each training run
40
  # The number of evaluation episodes will be constant, 10.
41
  # I need to convert to a regular list first
42
- # I could iterate over each element
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import argparse
2
  from numpy import load, ndarray
3
+ import os
4
 
5
  parser = argparse.ArgumentParser()
6
  parser.add_argument("-f", "--filepath", required=True, help="Specify the file path to the agent.", type=str)
 
9
  filepath = args.filepath
10
  npdata = load(filepath)
11
 
 
12
  evaluations = ndarray.tolist(npdata['results'])
 
 
13
  # print(evaluations)
14
  sorted_evals = []
15
  for eval in evaluations:
 
37
  # The number of elements is going to vary for each training run
38
  # The number of evaluation episodes will be constant, 10.
39
  # I need to convert to a regular list first
40
+ # I could iterate over each element
41
+
42
+ agent_dirs = []
43
+ for d in os.listdir("agents/"):
44
+ if "dqn_v2" in d:
45
+ agent_dirs.append(d)
46
+ # Now I have a list of dirs with the evals.
47
+ # Iterate over the dirs, append the file path, load the evals, calculate the average score of the eval, then return a list with averages
48
+ eval_list = []
49
+ for d in agent_dirs:
50
+ path = "agents/" + d + "/evaluations.npz"
51
+ evals = ndarray.tolist(load(path)["results"])
52
+ eval_list.append(evals)
53
+ # for i in eval_list:
54
+ # print(i)
55
+ # print()
56
+
57
+ def remove_outliers(evals: list) -> list:
58
+ trimmed = []
59
+ for eval in evals:
60
+ eval.sort()
61
+ eval.pop(0)
62
+ eval.pop()
63
+ trimmed.append(eval)
64
+ return trimmed
65
+
66
+ avgs = [[]]
67
+ index = 0
68
+ for i in eval_list:
69
+ avgs.append(i)
70
+ for j in i:
71
+ j.sort()
72
+ j.pop()
73
+ j.pop(0)
74
+ avgs[index].append(sum(j) / len(j))
75
+ index += 1
76
+
77
+ print(avgs)
78
+
79
+