|
import matplotlib as mpl |
|
|
|
mpl.use("Agg") |
|
import argparse |
|
import os |
|
import pandas as pd |
|
import seaborn as sns |
|
import matplotlib.pyplot as plt |
|
import matplotlib |
|
import IPython |
|
|
|
# Global text styling applied to every figure produced by this script.
font = dict(size=22)
matplotlib.rcParams["font.size"] = font["size"]

# NOTE(review): seaborn's set_context also appears to write rcParams
# "font.size" (paper scaling x font_scale), which would override the 22
# set above -- confirm which base font size is actually intended.
sns.set_context(context="paper", font_scale=2.0)
|
|
|
|
|
def mkdir_if_missing(dst_dir):
    """Create ``dst_dir`` (and any missing parent directories) if needed.

    ``os.makedirs(..., exist_ok=True)`` is used instead of a separate
    ``os.path.exists`` check. This avoids the window in which another
    process could create the directory between the check and the creation
    call, which would make the call fail with ``FileExistsError``.

    Args:
        dst_dir: Path of the directory to create.

    Raises:
        FileExistsError: if ``dst_dir`` already exists but is not a directory.
    """
    os.makedirs(dst_dir, exist_ok=True)
|
|
|
|
|
def save_figure(name, title=""):
    """Save the current pyplot figure as a PNG, then close it.

    The image is written to ``output/output_figures/<name[:30]>/output.png``.
    The directory is created if missing, and its path is printed first.

    Args:
        name: Identifier for the figure. Only its first 30 characters are
            used as the output directory name.
        title: Optional title applied to the current axes before saving.
            An empty string (the default) leaves the title untouched.
    """
    if title:
        plt.title(title)
    plt.tight_layout()

    # Build the path once. A plain f-string (not os.path.join) keeps an
    # absolute-looking ``name`` nested under output/output_figures.
    out_dir = f"output/output_figures/{name[:30]}"
    print(out_dir)
    mkdir_if_missing(out_dir)
    plt.savefig(f"{out_dir}/output.png")

    # close() rather than clf(): clf() only empties the figure and keeps it
    # registered with pyplot, so repeated calls would accumulate open figures.
    plt.close()
|
|
|
|
|
def _load_eval_results(multirun_out):
    """Read ``eval_results.csv`` for each run named in ``multirun_out``.

    Args:
        multirun_out: Comma-separated run directory names, relative to
            ``output/output_stats``. Whitespace around each name is ignored
            and empty entries (e.g. from a trailing comma) are skipped.

    Returns:
        A list of DataFrames, one per run whose CSV exists, in sorted run
        order. Runs without a CSV are reported with ``skip:`` and omitted.
    """
    # Skipping empty names matters: os.path.join(base, "") would otherwise
    # point at output/output_stats/eval_results.csv itself.
    rundirs = sorted(d.strip() for d in multirun_out.split(",") if d.strip())
    dfs = []
    for rundir in rundirs:
        statspath = os.path.join("output/output_stats", rundir, "eval_results.csv")
        if os.path.exists(statspath):
            dfs.append(pd.read_csv(statspath))
        else:
            print("skip:", statspath)
    return dfs


def main(multirun_out, title):
    """Bar-plot ``success`` per ``metric`` (hue = ``model``) across runs.

    The CSVs of all listed runs are concatenated and plotted as bars with
    one-standard-deviation error bars. Each bar is labelled with its value,
    and the figure is saved via ``save_figure``.

    Args:
        multirun_out: Comma-separated run directory names under
            ``output/output_stats``.
        title: Base figure title. The number of runs found is appended.

    Raises:
        ValueError: if none of the listed runs has an ``eval_results.csv``.
    """
    dfs = _load_eval_results(multirun_out)
    if not dfs:
        raise ValueError(
            f"no eval_results.csv found for any run in {multirun_out!r}"
        )

    # Each CSV's index starts at 0; renumber so the combined frame has
    # unique index labels.
    df = pd.concat(dfs, ignore_index=True)
    print(df)

    title += f" run: {len(dfs)} "

    _, ax = plt.subplots(figsize=(16, 8))
    sns.barplot(
        data=df,
        x="metric",
        y="success",
        hue="model",
        errorbar=("sd", 1),
        palette="deep",
        ax=ax,
    )

    # Write each bar's height, centred inside the bar.
    for container in ax.containers:
        ax.bar_label(container, label_type="center", fontsize="x-large", fmt="%.2f")

    save_figure(f"{multirun_out}_{title}", title)
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Plot success rates aggregated over one or more eval runs."
    )
    # Required: main() splits this string, so a missing value (None) would
    # otherwise crash later with an AttributeError instead of a usage error.
    parser.add_argument(
        "--multirun_out",
        type=str,
        required=True,
        help="comma-separated run directory names under output/output_stats",
    )
    parser.add_argument(
        "--title",
        type=str,
        default="",
        help="base title for the figure (the run count is appended)",
    )

    args = parser.parse_args()
    main(args.multirun_out, args.title)
|
|