"""Plot evaluation success rates aggregated over one or more runs.

Each run name in the comma-separated ``--multirun_out`` argument is looked up
as ``output/output_stats/<run>/eval_results.csv``. Every CSV that exists is
read; the frames are concatenated and drawn as a seaborn bar chart of
``success`` per ``metric`` with one bar per ``model`` (``hue``), error bars
showing one standard deviation, and each bar labelled with its value. The
figure is written to ``output/output_figures/<name[:30]>/output.png``.
"""
import matplotlib as mpl

# Select the non-interactive backend before pyplot is imported so the script
# can run on headless machines.
mpl.use("Agg")

import argparse
import os

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
import IPython  # NOTE(review): not used below; drop if no longer needed for debugging

font = {
    "size": 22,
}
matplotlib.rc("font", **font)
sns.set_context("paper", font_scale=2.0)

# Root directories for per-run statistics (input) and rendered figures (output).
STATS_ROOT = "output/output_stats"
FIGURES_ROOT = "output/output_figures"


def mkdir_if_missing(dst_dir):
    """Create ``dst_dir`` (including missing parents) unless it already exists."""
    # exist_ok=True avoids the race between an existence check and creation.
    os.makedirs(dst_dir, exist_ok=True)


def save_figure(name, title=""):
    """Save the current pyplot figure, then clear it.

    Args:
        name: Identifier for the output directory. Only the first 30
            characters are used, so distinct long names can collide.
        title: Optional plot title; applied only when non-empty.

    The image is written to ``output/output_figures/<name[:30]>/output.png``.
    """
    if title:
        plt.title(title)
    plt.tight_layout()
    # Build the directory once; kept as a plain string join (not os.path.join)
    # so a leading '/' in ``name`` cannot discard the output root.
    out_dir = f"{FIGURES_ROOT}/{name[:30]}"
    print(out_dir)
    mkdir_if_missing(out_dir)
    plt.savefig(f"{out_dir}/output.png")
    plt.clf()


def _load_eval_results(multirun_out):
    """Read ``eval_results.csv`` for each comma-separated run name.

    Runs are visited in sorted order. Runs without a results file are
    reported and skipped. Returns the list of DataFrames that were read.
    """
    dfs = []
    for rundir in sorted(multirun_out.split(",")):
        statspath = os.path.join(STATS_ROOT, rundir, "eval_results.csv")
        if os.path.exists(statspath):
            dfs.append(pd.read_csv(statspath))
        else:
            print("skip:", statspath)
    return dfs


def main(multirun_out, title):
    """Aggregate eval results from the given runs and save a bar chart.

    Args:
        multirun_out: Comma-separated run directory names under
            ``output/output_stats``.
        title: Base plot title; `` run: <N> `` is appended, where N is the
            number of runs whose results were found.

    Raises:
        ValueError: If none of the runs has an ``eval_results.csv``.
    """
    dfs = _load_eval_results(multirun_out)
    if not dfs:
        # pd.concat([]) would raise an unhelpful "No objects to concatenate".
        raise ValueError(f"no eval_results.csv found for runs: {multirun_out!r}")
    run_num = len(dfs)

    # Frames share column names; reset the index so rows from different runs
    # do not carry duplicate index labels into seaborn.
    df = pd.concat(dfs, ignore_index=True)
    print(df)

    title += f" run: {run_num} "

    fig, ax = plt.subplots(figsize=(16, 8))
    sns.barplot(
        data=df,
        x="metric",
        y="success",
        hue="model",
        errorbar=("sd", 1),
        palette="deep",
        ax=ax,
    )

    # Write each bar's value in the middle of the bar.
    for container in ax.containers:
        ax.bar_label(container, label_type="center", fontsize="x-large", fmt="%.2f")

    save_figure(f"{multirun_out}_{title}", title)
    plt.close(fig)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--multirun_out",
        type=str,
        required=True,
        help="comma-separated run directory names under output/output_stats",
    )
    parser.add_argument("--title", type=str, default="")
    args = parser.parse_args()
    main(args.multirun_out, args.title)