|
import os |
|
import sys |
|
import json |
|
|
|
from cliport import agents |
|
from cliport import tasks |
|
import argparse |
|
import datetime |
|
import matplotlib as mpl |
|
|
|
# Pick the non-interactive Agg backend before matplotlib.pyplot is imported
# further down, so figures can be rendered and saved without a display.
mpl.use(backend="Agg")
|
import argparse |
|
import os |
|
import pandas as pd |
|
import seaborn as sns |
|
import matplotlib.pyplot as plt |
|
import matplotlib |
|
import IPython |
|
import numpy as np |
|
# Global text styling: a 22pt base font for matplotlib, then seaborn's
# "paper" context with all font sizes scaled by 2x.
font = dict(size=22)
matplotlib.rc("font", **font)
sns.set_context(context="paper", font_scale=2.0)
|
|
|
|
|
def mkdir_if_missing(dst_dir):
    """Create ``dst_dir`` (and any missing parents) if it does not already exist.

    Args:
        dst_dir: Directory path to ensure.

    Raises:
        FileExistsError: if ``dst_dir`` exists but is not a directory.
    """
    # exist_ok=True makes the check and the creation a single step. A separate
    # os.path.exists() test followed by makedirs() races with any other process
    # that creates the same directory in between.
    os.makedirs(dst_dir, exist_ok=True)
|
|
|
|
|
def save_figure(name, title=""):
    """Save the current pyplot figure to ``output/output_figures/<name>/output.png``.

    The optional title is applied first, then the layout is tightened. The
    target directory is created if needed, and the figure is cleared after
    saving.

    Args:
        name: Sub-directory (may contain '/') under ``output/output_figures``.
        title: Plot title; skipped when empty.

    Returns:
        The path of the written PNG file.
    """
    fig_dir = f"output/output_figures/{name}"
    fig_path = f"{fig_dir}/output.png"
    # Report the exact file that savefig writes below.
    print(fig_path)

    if title:
        plt.title(title)
    plt.tight_layout()

    mkdir_if_missing(fig_dir)
    plt.savefig(fig_path)
    # Clear the canvas so any later plot does not draw on top of this one.
    plt.clf()
    return fig_path
|
|
|
|
|
def print_and_write(file_handle, text):
    """Echo ``text`` to stdout and mirror it into ``file_handle`` when one is given.

    Args:
        file_handle: Writable text stream, or ``None`` to only print.
        text: Line to emit; a newline is appended for the file copy only.

    Returns:
        ``text``, unchanged.
    """
    print(text)
    if file_handle is None:
        return text
    file_handle.write(text + "\n")
    return text
|
|
|
# Command-line interface.
#   --results/-r : experiments folder, relative to $GENSIM_ROOT.
#   --single/-s  : boolean flag.
# NOTE(review): args.single is parsed but never read anywhere in this script.
parser = argparse.ArgumentParser()
parser.add_argument("--results", "-r", default="exps/exps-singletask", type=str)
parser.add_argument("--single", "-s", default=False, action="store_true")
args = parser.parse_args()
|
|
|
# Root of the GenSim checkout; the experiments folder is resolved relative to it.
root_folder = os.environ.get('GENSIM_ROOT')
if root_folder is None:
    # Fail with an actionable message instead of a bare KeyError traceback.
    raise SystemExit(
        "GENSIM_ROOT environment variable is not set; "
        "point it at the GenSim root directory."
    )
exp_folder = os.path.join(root_folder, args.results)

# Output directories for figures, the CLIPort training summary and statistics.
for out_dir in ('output/output_figures', 'output/cliport_output', 'output/output_stat'):
    mkdir_if_missing(out_dir)

# Summary log, opened in append mode so repeated runs accumulate in one file.
output_stat_file = os.path.join('output', 'cliport_output', 'cliport-training.txt')
file_handle = open(output_stat_file, 'a', encoding='utf-8')
|
|
|
# Task and agent names come from cliport's `names` mappings; demos_list is the
# set of training-demo counts whose runs are looked for.
tasks_list = list(tasks.names.keys())
agents_list = list(agents.names.keys())
demos_list = [1, 5, 10, 20, 30, 50, 100, 200, 1000]

# results maps "<task>-<agent>-n<demos>-<single|multi>-<val|test>" to the parsed
# contents of a matching JSON file in that run's checkpoints/ directory.
results = {}
for t in tasks_list:
    for a in agents_list:
        for d in demos_list:
            task_folder = f'{t}-{a}-n{d}-train'
            task_folder_path = os.path.join(exp_folder, task_folder, 'checkpoints')
            # isdir (not exists) because os.listdir below needs a directory.
            if not os.path.isdir(task_folder_path):
                continue
            print(f"train {task_folder_path}")

            # Only real *.json files (not e.g. "x.json.bak"). Sorted so that, when
            # two files map to the same key, the one that wins is deterministic.
            jsons = sorted(
                name for name in os.listdir(task_folder_path) if name.endswith('.json')
            )
            for j in jsons:
                # The model and eval type are inferred from substrings of the file name.
                model_type = 'multi' if 'multi' in j else 'single'
                eval_type = 'val' if 'val' in j else 'test'

                with open(os.path.join(task_folder_path, j)) as f:
                    res = json.load(f)

                results[f'{t}-{a}-n{d}-{model_type}-{eval_type}'] = res
|
|
|
# Timestamped header, so successive runs appended to the same log can be told apart.
dt_string = datetime.datetime.now().strftime("%d_%m_%Y_%H:%M:%S")
print_and_write(file_handle, f"==========================={dt_string}=========================\n")
print_and_write(file_handle, f'Experiments folder: {exp_folder}\n')

# Plot data: one row per found (eval split, task, agent, demos, model type), holding
# the best checkpoint's mean reward. An "Average" row over all of them is appended last.
data = {'task': [], 'success': []}

for eval_type in ['val', 'test']:
    print_and_write(file_handle, f'----- {eval_type.upper()} -----\n')
    for t in tasks_list:
        for a in agents_list:
            for d in demos_list:
                for model_type in ['single', 'multi']:
                    eval_key = f'{t}-{a}-n{d}-{model_type}-{eval_type}'
                    if eval_key not in results:
                        continue

                    print_and_write(file_handle, f'{eval_key} {t} | Train Demos: {d}')
                    res = results[eval_key]
                    if not res:
                        # max() over an empty JSON object would raise and abort the whole report.
                        print_and_write(file_handle, f'\tno checkpoints evaluated : {a} | {model_type}\n')
                        continue

                    # res presumably maps checkpoint name -> {'mean_reward': float, ...}.
                    # The highest mean reward wins; ties go to the larger checkpoint name.
                    best_score, best_ckpt = max((v['mean_reward'], k) for k, v in res.items())

                    print_and_write(file_handle, f'\t{best_score*100:1.1f} : {a} | {model_type}\n')
                    data['task'].append(t)
                    data['success'].append(best_score)

data['task'].append("Average")
data['success'].append(np.mean(data["success"]))
|
|
|
|
|
|
|
df = pd.DataFrame.from_dict(data)
title = args.results + "_res"

fig, ax = plt.subplots(figsize=(16, 8))
# Draw onto the axes created above rather than relying on the implicit current axes.
# Rows sharing a task are aggregated by seaborn, with a +/-1 standard-deviation error bar.
sns.barplot(
    data=df, x="task", y="success", errorbar=("sd", 1), palette="deep", ax=ax
)

# Print each bar's value inside the bar.
for container in ax.containers:
    ax.bar_label(container, label_type="center", fontsize="x-large", fmt="%.2f")

# Wrap long task names onto several lines, one hyphen-separated word per line.
ax.set_xticklabels(
    ['\n'.join(str(xlabel.get_text()).split("-")) for xlabel in ax.get_xticklabels()]
)

save_figure(f"{title}", title)

# Release the figure and flush/close the summary log opened earlier in the script.
plt.close(fig)
file_handle.close()
|
|