# ALE-Pacman-v5 / plot_improvement.py
# Last commit by ledmands (71d2358): "Added framework for saving charts in plot_improvement.py"
import argparse
import numpy as np
import os
from matplotlib import pyplot as plt
def calc_stats(filepath):
    """Return the outlier-trimmed mean and standard deviation of evaluation scores.

    Loads the ``results`` array from the ``.npz`` archive at *filepath*, sorts
    it along axis 1, drops the first and last element of axis 1 (i.e. the
    smallest and largest value of each row), and computes the mean and
    standard deviation over everything that remains.

    Args:
        filepath: Path to an ``.npz`` archive containing a ``results`` array
            with at least two dimensions. Presumably one row per evaluation
            and one column per evaluation episode -- confirm against the
            code that writes the file.

    Returns:
        A tuple ``(avg, std)``, each rounded to two decimal places.

    Raises:
        ValueError: if ``results`` has fewer than two dimensions or fewer
            than three entries along axis 1, because trimming would leave
            nothing to average.
    """
    # np.load on an .npz returns an NpzFile that keeps the archive open;
    # the context manager closes it once the array has been read.
    with np.load(filepath) as archive:
        data = archive["results"]

    if data.ndim < 2 or data.shape[1] < 3:
        raise ValueError(
            f"{filepath}: 'results' needs at least 3 entries per row to trim "
            f"the min and max, got shape {data.shape}"
        )

    # Sort each row, then drop its lowest and highest value (outlier trim).
    data = np.sort(data, axis=1)[:, 1:-1]
    avg = round(np.mean(data), 2)
    std = round(np.std(data), 2)
    return avg, std
# parser = argparse.ArgumentParser()
# parser.add_argument("-f", "--filepath", required=True, help="Specify the file path to the agent.", type=str)
# parser.add_argument("-s", "--save", help="Specify whether to save the chart.", action="store_const", const=True)
# args = parser.parse_args()
# Get the file paths and store in list.
# For each file path, I want to calculate the mean reward. This would be the mean reward for the training run over all evaluations.
# For each file path, append the mean reward to an averages list
# Plot the averages!
# Plot, for every saved training run of one agent, the outlier-trimmed mean
# and standard deviation of its evaluation scores as computed by calc_stats.
AGENTS_DIR = "agents"
AGENT_NAME = "dqn_v2"
# Set to a file path such as "charts/fig1" to also write the chart to disk.
SAVE_PATH = None

# Each sub-directory of AGENTS_DIR whose name contains AGENT_NAME is treated
# as one training run holding an evaluations.npz log. os.listdir returns
# entries in arbitrary order, so sort to keep the run numbering stable, and
# skip plain files whose names merely contain AGENT_NAME.
filepaths = [
    os.path.join(AGENTS_DIR, d, "evaluations.npz")
    for d in sorted(os.listdir(AGENTS_DIR))
    if AGENT_NAME in d and os.path.isdir(os.path.join(AGENTS_DIR, d))
]
if not filepaths:
    raise SystemExit(
        f"No training runs matching {AGENT_NAME!r} found in {AGENTS_DIR!r}."
    )

means = []
stds = []
for path in filepaths:
    avg, std = calc_stats(path)
    means.append(avg)
    stds.append(std)

runs = np.arange(1, len(filepaths) + 1)

# Draw the two series side by side: bars sharing the same x position would
# overlap, letting the taller bar hide the shorter one.
bar_width = 0.4
plt.bar(runs - bar_width / 2, means, bar_width, label="Mean evaluation score")
plt.bar(runs + bar_width / 2, stds, bar_width, label="Standard deviation")
plt.xticks(runs)
plt.xlabel("training runs")
plt.ylabel("score")
plt.legend()
plt.title(
    "Average Evaluation Score and Standard Deviation\n"
    f"Adjusted for Outliers, Agent: {AGENT_NAME}"
)

# Save before show(): once the interactive window is closed the figure is
# released, and a later savefig would write an empty image.
if SAVE_PATH:
    save_dir = os.path.dirname(SAVE_PATH)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    plt.savefig(SAVE_PATH)
plt.show()