#!/usr/bin/env python import re import itertools import math from itertools import chain import time # import seaborn import numpy as np import os from collections import OrderedDict, defaultdict import pandas as pd import matplotlib.pyplot as plt import sys from termcolor import cprint, colored from pathlib import Path import pickle eval_metric = "test_success_rates" # eval_metric = "exploration_bonus_mean" super_title = "" # super_title = "PPO - No exploration bonus" # super_title = "Count Based exploration bonus (Grid Search)" # super_title = "PPO + RND" # super_title = "PPO + RIDE" agg_title = "" color_dict = None eval_filename = None max_frames = 20_000_000 draw_legend = True per_seed = False study_eval = True plot_train = True plot_test = True plot_aggregated_test = False plot_only_aggregated_test = False train_inc_font = 3 xnbins = 4 ynbins = 3 steps_denom = 1e6 # Global vas for tracking and labeling data at load time. exp_idx = 0 label_parser_dict = None label_parser = lambda l, _, label_parser_dict: l # smooth_factor = 100 smooth_factor = 10 smooth_factor = 0 print("smooth factor:", smooth_factor) eval_smooth_factor = 1 leg_size = 30 def smooth(x_, n=50): if type(x_) == list: x_ = np.array(x_) return np.array([x_[max(i - n, 0):i + 1].mean() for i in range(len(x_))]) sort_test = False def sort_test_set(env_name): helps = [ "LanguageFeedback", "LanguageColor", "Pointing", "Emulation", ] problems = [ "Boxes", "Switches", "Generators", "Marble", "Doors", "Levers", ] env_names = [] for p in problems: for h in helps: env_names.append(h+p) env_names.extend([ "LeverDoorColl", "MarblePushColl", "MarblePassColl", "AppleStealing" ]) for i, en in enumerate(env_names): if en in env_name: return i raise ValueError(f"Test env {env_name} not known") subsample_step = 1 load_subsample_step = 1 x_lim = 0 max_x_lim = 17 max_x_lim = np.inf # x_lim = 100 summary_dict = {} summary_dict_colors = {} # default_colors = ["blue","orange","green","magenta", "brown", "red",'black',"grey",u'#ff7f0e', # "cyan", "pink",'purple', u'#1f77b4', # "darkorchid","sienna","lightpink", "indigo","mediumseagreen",'aqua', # 'deeppink','silver','khaki','goldenrod','y','y','y','y','y','y','y','y','y','y','y','y' ] + ['y']*50 default_colors_ = ["blue","orange","green","magenta", "brown", "red",'black',"grey",u'#ff7f0e', "cyan", "pink",'purple', u'#1f77b4', "darkorchid","sienna","lightpink", "indigo","mediumseagreen",'aqua', 'deeppink','silver','khaki','goldenrod'] * 100 def get_eval_data(logdir, eval_metric): eval_data = defaultdict(lambda :defaultdict(list)) for root, _, files in os.walk(logdir): for file in files: if 'testing_' in file: assert ".pkl" in file test_env_name = file.lstrip("testing_").rstrip(".pkl") try: with open(root+"/"+file, "rb") as f: seed_eval_data = pickle.load(f) except: print("Pickle not loaded: ", root+"/"+file) time.sleep(1) continue eval_data[test_env_name]["values"].append(seed_eval_data[eval_metric]) eval_data[test_env_name]["steps"].append(seed_eval_data["test_step_nb"]) # if 'log.csv' in files: # run_name = root[8:] # exp_name = None # # config = None # exp_idx += 1 # # # load progress data # try: # print(os.path.join(root, 'log.csv')) # exp_data = pd.read_csv(os.path.join(root, 'log.csv')) # except: # size = (Path(root) / 'log.csv').stat().st_size # if size == 0: # raise ValueError("CSV {} empty".format(os.path.join(root, 'log.csv'))) # else: # raise ValueError("CSV {} faulty".format(os.path.join(root, 'log.csv'))) # # exp_data = exp_data[::load_subsample_step] # data_dict = exp_data.to_dict("list") # # data_dict['config'] = config # nb_epochs = len(data_dict['frames']) # print('{} -> {}'.format(run_name, nb_epochs)) for test_env, seed_data in eval_data.items(): min_len_seed = min([len(s) for s in seed_data['steps']]) eval_data[test_env]["values"] = np.array([s[:min_len_seed] for s in eval_data[test_env]["values"]]) eval_data[test_env]["steps"] = np.array([s[:min_len_seed] for s in eval_data[test_env]["steps"]]) return eval_data def get_all_runs(logdir, load_subsample_step=1): """ Recursively look through logdir for output files produced by Assumes that any file "log.csv" is a valid hit. """ global exp_idx global units datasets = [] for root, _, files in os.walk(logdir): if 'log.csv' in files: if (Path(root) / 'log.csv').stat().st_size == 0: print("CSV {} empty".format(os.path.join(root, 'log.csv'))) continue run_name = root[8:] exp_name = None config = None exp_idx += 1 # load progress data try: exp_data = pd.read_csv(os.path.join(root, 'log.csv')) print("Loaded:", os.path.join(root, 'log.csv')) except: raise ValueError("CSV {} faulty".format(os.path.join(root, 'log.csv'))) exp_data = exp_data[::load_subsample_step] data_dict = exp_data.to_dict("list") data_dict['config'] = config nb_epochs = len(data_dict['frames']) if nb_epochs == 1: print(f'{run_name} -> {colored(f"nb_epochs {nb_epochs}", "red")}') else: print('{} -> nb_epochs {}'.format(run_name, nb_epochs)) datasets.append(data_dict) return datasets def get_datasets(rootdir, load_only="", load_subsample_step=1, ignore_patterns=("ignore"), require_patterns=()): _, models_list, _ = next(os.walk(rootdir)) for dir_name in models_list.copy(): # add "ignore" in a directory name to avoid loading its content for ignore_pattern in ignore_patterns: if ignore_pattern in dir_name or load_only not in dir_name: if dir_name in models_list: models_list.remove(dir_name) if len(require_patterns) > 0: if not any([require_pattern in dir_name for require_pattern in require_patterns]): if dir_name in models_list: models_list.remove(dir_name) for expe_name in list(labels.keys()): if expe_name not in models_list: del labels[expe_name] # setting per-model type colors for i, m_name in enumerate(models_list): for m_type, m_color in per_model_colors.items(): if m_type in m_name: colors[m_name] = m_color print("extracting data for {}...".format(m_name)) m_id = m_name models_saves[m_id] = OrderedDict() models_saves[m_id]['data'] = get_all_runs(rootdir+m_name, load_subsample_step=load_subsample_step) print("done") if m_name not in labels: labels[m_name] = m_name model_eval_data[m_id] = get_eval_data(logdir=rootdir+m_name, eval_metric=eval_metric) """ retrieve all experiences located in "data to vizu" folder """ labels = OrderedDict() per_model_colors = OrderedDict() # per_model_colors = OrderedDict([('ALP-GMM',u'#1f77b4'), # ('hmn','pink'), # ('ADR','black')]) # LOAD DATA models_saves = OrderedDict() colors = OrderedDict() model_eval_data = OrderedDict() static_lines = {} # get_datasets("storage/",load_only="RERUN_WizardGuide") # get_datasets("storage/",load_only="RERUN_WizardTwoGuides") try: load_pattern = eval(sys.argv[1]) except: load_pattern = sys.argv[1] ignore_patterns = ["_ignore_"] require_patterns = [ "_" ] # require_patterns = [ # "dummy_cs_jz_scaf_A_E_N_A_E", # "03-12_dummy_cs_jz_formats_AE", # ] # # def label_parser(label, figure_id, label_parser_dict=None): # if "single" in label: # ty = "single" # elif "group" in label: # ty = "group" # # if "asoc" in label: # return f"Asocial_pretrain({ty})" # # if "exp_soc" in label: # return f"Role_B_pretrain({ty})" # # return label # # # DUMMY FORMATS # require_patterns = [ # "03-12_dummy_cs_formats_CBL", # "dummy_cs_formats_CBL_N_rec_5" # "03-12_dummy_cs_jz_formats_", # "dummy_cs_jz_formats_N_rec_5" # ] # def label_parser(label, figure_id, label_parser_dict=None): # if "CBL" in label: # eb = "CBL" # else: # eb = "no_bonus" # # if "AE" in label: # label = f"AE_PPO_{eb}" # elif "E" in label: # label = f"E_PPO_{eb}" # elif "A" in label: # label = f"A_PPO_{eb}" # elif "N" in label: # label = f"N_PPO_{eb}" # # return label # # DUMMY CLASSIC # require_patterns = [ # "07-12_dummy_cs_NEW2_Pointing_sm_CB_very_small", # "dummy_cs_JA_Pointing_CB_sm", # "06-12_dummy_cs_NEW_Color_CBL", # "dummy_cs_JA_Color_CBL_new" # "07-12_dummy_cs_NEW2_Feedback_CBL", # "dummy_cs_JA_Feedback_CBL_new" # "08-12_dummy_cs_emulation_no_distr_rec_5_CB_exploration-bonus-type_cell_exploration-bonus-params__1_50", # "08-12_dummy_cs_emulation_no_distr_rec_5_CB", # "dummy_cs_RR_ft_NEW_single_CB_marble_pass_B_exp_soc", # "dummy_cs_RR_ft_NEW_single_CB_marble_pass_B_contr_asoc", # "dummy_cs_RR_ft_NEW_group_CB_marble_pass_A_exp_soc", # "dummy_cs_RR_ft_NEW_group_CB_marble_pass_A_contr_asoc" # "03-12_dummy_cs_jz_formats_A", # "03-12_dummy_cs_jz_formats_E", # "03-12_dummy_cs_jz_formats_AE", # "dummy_cs_jz_formats_N_rec_5" # "03-12_dummy_cs_formats_CBL_A", # "03-12_dummy_cs_formats_CBL_E", # "03-12_dummy_cs_formats_CBL_AE", # "dummy_cs_formats_CBL_N_rec_5" # "03-12_dummy_cs_jz_formats_AE", # "dummy_cs_jz_scaf_A_E_N_A_E_full-AEfull", # "dummy_cs_jz_scaf_A_E_N_A_E_scaf_full-AEfull", # ] # def label_parser(label, figure_id, label_parser_dict=None): # label = label.replace("07-12_dummy_cs_NEW2_Pointing_sm_CB_very_small", "PPO_CB") # label = label.replace("dummy_cs_JA_Pointing_CB_sm", "JA_PPO_CB") # # label = label.replace("06-12_dummy_cs_NEW_Color_CBL", "PPO_CBL") # label = label.replace("dummy_cs_JA_Color_CBL_new", "JA_PPO_CBL") # # label = label.replace("07-12_dummy_cs_NEW2_Feedback_CBL", "PPO_CBL") # label = label.replace("dummy_cs_JA_Feedback_CBL_new", "JA_PPO_CBL") # # label = label.replace( # "08-12_dummy_cs_emulation_no_distr_rec_5_CB_exploration-bonus-type_cell_exploration-bonus-params__1_50", # "PPO_CB_1") # label = label.replace( # "08-12_dummy_cs_emulation_no_distr_rec_5_CB_exploration-bonus-type_cell_exploration-bonus-params__1_50", # "PPO_CB_1") # # label = label.replace("dummy_cs_RR_ft_NEW_single_CB_marble_pass_B_exp_soc", "PPO_CB_role_B_single") # label = label.replace("dummy_cs_RR_ft_NEW_single_CB_marble_pass_B_contr_asoc", "PPO_CB_asoc_single") # # label = label.replace("dummy_cs_RR_ft_NEW_group_CB_marble_pass_A_exp_soc", "PPO_CB_role_B_group") # label = label.replace("dummy_cs_RR_ft_NEW_group_CB_marble_pass_A_contr_asoc", "PPO_CB_asoc_group") # # label = label.replace( # "03-12_dummy_cs_formats_CBL_A_rec_5_env_SocialAI-ALangFeedbackTrainFormatsCSParamEnv-v1_recurrence_5_test-set-name_AFormatsTestSet_exploration-bonus-type_lang", # "PPO_CBL_Ask") # label = label.replace( # "03-12_dummy_cs_formats_CBL_E_rec_5_env_SocialAI-ELangFeedbackTrainFormatsCSParamEnv-v1_recurrence_5_test-set-name_EFormatsTestSet_exploration-bonus-type_lang", # "PPO_CBL_Eye_contact") # label = label.replace( # "03-12_dummy_cs_formats_CBL_AE_rec_5_env_SocialAI-AELangFeedbackTrainFormatsCSParamEnv-v1_recurrence_5_test-set-name_AEFormatsTestSet_exploration-bonus-type_lang", # "PPO_CBL_Ask_Eye_contact") # label = label.replace("dummy_cs_formats_CBL_N_rec_5", "PPO_CBL_No") # # label = label.replace( # "03-12_dummy_cs_jz_formats_E_rec_5_env_SocialAI-ELangFeedbackTrainFormatsCSParamEnv-v1_recurrence_5_test-set-name_EFormatsTestSet", # "PPO_no_bonus_Eye_contact") # label = label.replace( # "03-12_dummy_cs_jz_formats_A_rec_5_env_SocialAI-ALangFeedbackTrainFormatsCSParamEnv-v1_recurrence_5_test-set-name_AFormatsTestSet", # "PPO_no_bonus_Ask") # label = label.replace( # "03-12_dummy_cs_jz_formats_AE_rec_5_env_SocialAI-AELangFeedbackTrainFormatsCSParamEnv-v1_recurrence_5_test-set-name_AEFormatsTestSet", # "PPO_no_bonus_Ask_Eye_contact") # label = label.replace("dummy_cs_jz_formats_N_rec_5", "PPO_no_bonus_No") # # label = label.replace("03-12_dummy_cs_jz_formats_AE", "PPO_no_bonus_no_scaf") # label = label.replace("dummy_cs_jz_scaf_A_E_N_A_E_full-AEfull", "PPO_no_bonus_scaf_4") # label = label.replace("dummy_cs_jz_scaf_A_E_N_A_E_scaf_full-AEfull", "PPO_no_bonus_scaf_8") # # return label # Final case studies require_patterns = [ "_", # pointing # "04-01_Pointing_CB_heldout_doors", # # role reversal # "03-01_RR_ft_single_CB_marble_pass_A_asoc_contr", # "03-01_RR_ft_single_CB_marble_pass_A_soc_exp", # "05-01_RR_ft_group_50M_CB_marble_pass_A_asoc_contr", # "05-01_RR_ft_group_50M_CB_marble_pass_A_soc_exp", # scaffolding # "05-01_scaffolding_50M_no", # "05-01_scaffolding_50M_acl_4_acl-type_intro_seq", # "05-01_scaffolding_50M_acl_8_acl-type_intro_seq_scaf", ] def label_parser(label, figure_id, label_parser_dict=None): label = label.replace("04-01_Pointing_CB_heldout_doors", "PPO_CB") label = label.replace("05-01_scaffolding_50M_no_acl", "PPO_no_scaf") label = label.replace("05-01_scaffolding_50M_acl_4_acl-type_intro_seq", "PPO_scaf_4") label = label.replace("05-01_scaffolding_50M_acl_8_acl-type_intro_seq_scaf", "PPO_scaf_8") label = label.replace("03-01_RR_ft_single_CB_marble_pass_A_soc_exp", "PPO_CB_role_B") label = label.replace("03-01_RR_ft_single_CB_marble_pass_A_asoc_contr", "PPO_CB_asocial") label = label.replace("05-01_RR_ft_group_50M_CB_marble_pass_A_soc_exp", "PPO_CB_role_B") label = label.replace("05-01_RR_ft_group_50M_CB_marble_pass_A_asoc_contr", "PPO_CB_asocial") return label color_dict = { # JA # "JA_PPO_CBL": "blue", # "PPO_CBL": "orange", # RR group # "PPO_CB_role_B_group": "orange", # "PPO_CB_asoc_group": "blue" # formats No # "PPO_no_bonus_No": "blue", # "PPO_no_bonus_Eye_contact": "magenta", # "PPO_no_bonus_Ask": "orange", # "PPO_no_bonus_Ask_Eye_contact": "green" # formats CBL # "PPO_CBL_No": "blue", # "PPO_CBL_Eye_contact": "magenta", # "PPO_CBL_Ask": "orange", # "PPO_CBL_Ask_Eye_contact": "green" } # # POINTING_GENERALIZATION (DUMMY) # require_patterns = [ # "29-10_SAI_Pointing_CS_PPO_CB_", # "29-10_SAI_LangColor_CS_PPO_CB_" # ] # # color_dict = { # "dummy_cs_JA_Feedback_CBL_new": "blue", # "dummy_cs_Feedback_CBL": "orange", # } # # def label_parser(label, figure_id, label_parser_dict=None): # label = label.split("_env_")[0].split("SAI_")[1] # label=label.replace("Pointing_CS_PPO_CB", "PPO_CB_train(DUMMY)") # label=label.replace("LangColor_CS_PPO_CB", "PPO_CB_test(DUMMY)") # return label # # eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/Pointing_gen_eval.png" # # FEEDBACK GENERALIZATION (DUMMY) # require_patterns = [ # "29-10_SAI_LangFeedback_CS_PPO_CBL_", # "29-10_SAI_LangColor_CS_PPO_CB_" # ] # # color_dict = { # "PPO_CBL_train(DUMMY)": "blue", # "PPO_CBL_test(DUMMY)": "maroon", # } # # def label_parser(label, figure_id, label_parser_dict=None): # label = label.split("_env_")[0].split("SAI_")[1] # label=label.replace("LangFeedback_CS_PPO_CBL", "PPO_CBL_train(DUMMY)") # label=label.replace("LangColor_CS_PPO_CB", "PPO_CBL_test(DUMMY)") # return label # # eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/Feedback_gen_eval.png" # # COLOR GENERALIZATION (DUMMY) # require_patterns = [ # "29-10_SAI_LangColor_CS_PPO_CBL_", # "29-10_SAI_LangColor_CS_PPO_CB_" # ] # # color_dict = { # "PPO_CBL_train(DUMMY)": "blue", # "PPO_CBL_test(DUMMY)": "maroon", # } # # def label_parser(label, figure_id, label_parser_dict=None): # label = label.split("_env_")[0].split("SAI_")[1] # label=label.replace("LangColor_CS_PPO_CBL", "PPO_CBL_train(DUMMY)") # label=label.replace("LangColor_CS_PPO_CB", "PPO_CBL_test(DUMMY)") # return label # # eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/Color_gen_eval.png" # # POINTING - PILOT # require_patterns = [ # "29-10_SAI_Pointing_CS_PPO_", # ] # # color_dict = { # "PPO_RIDE": "orange", # "PPO_RND": "magenta", # "PPO_no": "maroon", # "PPO_CBL": "green", # "PPO_CB": "blue", # } # # def label_parser(label, figure_id, label_parser_dict=None): # label = label.split("_env_")[0].split("SAI_")[1] # label=label.replace("Pointing_CS_", "") # return label # # # eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/Pointing_eval.png" # LANGCOLOR - 7 Colors - PILOT # require_patterns = [ # "29-10_SAI_LangColor_CS_PPO_", # ] # # color_dict = { # "PPO_RIDE": "orange", # "PPO_RND": "magenta", # "PPO_no": "maroon", # "PPO_CBL": "green", # "PPO_CB": "blue", # } # # def label_parser(label, figure_id, label_parser_dict=None): # label = label.split("_env_")[0].split("SAI_")[1] # label=label.replace("LangColor_CS_", "") # return label # # eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/Color_eval.png" # # LangColor - CBL - 3 5 7 # require_patterns = [ # "02-11_SAI_LangColor_CS_5C_PPO_CBL", # "02-11_SAI_LangColor_CS_3C_PPO_CBL", # "29-10_SAI_LangColor_CS_PPO_CBL" # ] # RND RIDE reference : RIDE > RND > no # require_patterns = [ # "24-08_new_ref", # ] # # # LANG FEEDBACK # require_patterns = [ # "24-10_SAI_LangFeedback_CS_PPO_", # "29-10_SAI_LangFeedback_CS_PPO_", # ] # color_dict = { # "PPO_RIDE": "orange", # "PPO_RND": "magenta", # "PPO_no": "maroon", # "PPO_CBL": "green", # "PPO_CB": "blue", # } # # def label_parser(label, figure_id, label_parser_dict=None): # label = label.split("_env_")[0].split("SAI_")[1] # label=label.replace("LangFeedback_CS_", "") # return label # # # eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/Feedback_eval.png" # # # ROLE REVERSAL - group (DUMMY) # require_patterns = [ # "24-10_SAI_LangFeedback_CS_PPO_CB_", # "29-10_SAI_LangFeedback_CS_PPO_CBL_", # ] # color_dict = { # "PPO_CB_experimental": "green", # "PPO_CB_control": "blue", # } # color_dict=None # # def label_parser(label, figure_id, label_parser_dict=None): # label = label.split("_env_")[0].split("SAI_")[1] # label=label.replace("LangFeedback_CS_", "") # # label=label.replace("PPO_CB", "PPO_CB_control") # label=label.replace("controlL", "experimental") # # return label # # eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/RR_dummy_group.png" # # ROLE REVERSAL - single (DUMMY) # require_patterns = [ # "24-10_SAI_LangFeedback_CS_PPO_CB_", # "24-10_SAI_LangFeedback_CS_PPO_no_", # ] # color_dict = { # "PPO_CB_experimental": "green", # "PPO_CB_control": "blue", # } # color_dict=None # # def label_parser(label, figure_id, label_parser_dict=None): # label = label.split("_env_")[0].split("SAI_")[1] # label=label.replace("LangFeedback_CS_", "") # # label=label.replace("PPO_CB", "PPO_CB_control") # label=label.replace("PPO_no", "PPO_CB_experimental") # # return label # # eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/RR_dummy_single.png" # # IMITATION train (DUMMY) # require_patterns = [ # "29-10_SAI_LangFeedback_CS_PPO_CBL_", # "29-10_SAI_Pointing_CS_PPO_RIDE", # ] # # color_dict = { # "PPO_CB_no_distr(DUMMY)": "magenta", # "PPO_CB_distr(DUMMY)": "orange", # } # # def label_parser(label, figure_id, label_parser_dict=None): # label = label.split("_env_")[0].split("SAI_")[1] # label=label.replace("LangFeedback_CS_PPO_CBL", "PPO_CB_no_distr(DUMMY)") # label=label.replace("Pointing_CS_PPO_RIDE", "PPO_CB_distr(DUMMY)") # return label # # eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/Imitation_train.png" # # IMITATION test (DUMMY) # require_patterns = [ # "29-10_SAI_LangFeedback_CS_PPO_CBL_", # "29-10_SAI_Pointing_CS_PPO_RIDE", # ] # # color_dict = { # "PPO_CB_no_distr(DUMMY)": "magenta", # "PPO_CB_distr(DUMMY)": "orange", # } # # def label_parser(label, figure_id, label_parser_dict=None): # label = label.split("_env_")[0].split("SAI_")[1] # label=label.replace("LangFeedback_CS_PPO_CBL", "PPO_CB_no_distr(DUMMY)") # label=label.replace("Pointing_CS_PPO_RIDE", "PPO_CB_distr(DUMMY)") # return label # # eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/Imitation_test.png" # JA_POINTING # require_patterns = [ # "29-10_SAI_Pointing_CS_PPO_CB_", # "04-11_SAI_JA_Pointing_CS_PPO_CB_less", # less reward # ] # color_dict = { # "JA_Pointing_PPO_CB": "orange", # "Pointing_PPO_CB": "blue", # } # # def label_parser(label, figure_id, label_parser_dict=None): # label = label.split("_env_")[0].split("SAI_")[1] # label=label.replace("_CS_", "_") # label=label.replace("_less_", "") # return label # # eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/JA_Pointing_eval.png" # # JA_COLORS (JA, no) x (3,5,7) # max_x_lim = 17 # require_patterns = [ # # "02-11_SAI_JA_LangColor", # max_x_lim = 17 # "02-11_SAI_JA_LangColor_CS_3C", # max_x_lim = 17 # # "02-11_SAI_LangColor_CS_5C_PPO_CBL", # max_x_lim = 17 # "02-11_SAI_LangColor_CS_3C_PPO_CBL", # # "29-10_SAI_LangColor_CS_PPO_CBL" # ] # color_dict = { # "JA_LangColor_PPO_CBL": "orange", # "LangColor_PPO_CBL": "blue", # } # def label_parser(label, figure_id, label_parser_dict=None): # label = label.split("_env_")[0].split("SAI_")[1] # label=label.replace("_CS_", "_") # label=label.replace("_3C_", "_") # return label # eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/JA_Color_eval.png" # JA_FEEDBACK -> max_xlim=17 # max_x_lim = 17 # require_patterns = [ # "02-11_SAI_JA_LangFeedback_CS_PPO_CBL_", # "29-10_SAI_LangFeedback_CS_PPO_CBL_", # "dummy_cs_F", # "dummy_cs_JA_F" # ] # color_dict = { # "JA_LangFeedback_PPO_CBL": "orange", # "LangFeedback_PPO_CBL": "blue", # } # # def label_parser(label, figure_id, label_parser_dict=None): # label = label.split("_env_")[0].split("SAI_")[1] # label=label.replace("_CS_", "_") # return label # # eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/JA_Feedback_eval.png" # # Formats CBL # require_patterns = [ # "03-11_SAI_LangFeedback_CS_F_NO_PPO_CBL_env_SocialAI", # "29-10_SAI_LangFeedback_CS_PPO_CBL_env_SocialAI", # "03-11_SAI_LangFeedback_CS_F_ASK_PPO_CBL_env_SocialAI", # "03-11_SAI_LangFeedback_CS_F_ASK_EYE_PPO_CBL_env_SocialAI", # ] # color_dict = { # "LangFeedback_Eye_PPO_CBL": "blue", # "LangFeedback_Ask_PPO_CBL": "orange", # "LangFeedback_NO_PPO_CBL": "green", # "LangFeedback_AskEye_PPO_CBL": "magenta", # } # # def label_parser(label, figure_id, label_parser_dict=None): # label = label.split("_env_")[0].split("SAI_")[1] # label=label.replace("_CS_", "_") # label=label.replace("_F_", "_") # # label=label.replace("LangFeedback_PPO", "LangFeedback_EYE_PPO") # # label=label.replace("EYE", "Eye") # label=label.replace("No", "No") # label=label.replace("ASK", "Ask") # label=label.replace("Ask_Eye", "AskEye") # return label # # eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/Formats_CBL_eval.png" # # Formats NO # require_patterns = [ # "24-10_SAI_LangFeedback_CS_PPO_no", # EYE # "04-11_SAI_LangFeedback_CS_F_NO_PPO_NO_env_SocialAI", # "04-11_SAI_LangFeedback_CS_F_ASK_PPO_NO_env_SocialAI", # "04-11_SAI_LangFeedback_CS_F_ASK_EYE_PPO_NO_env_SocialAI", # ] # # color_dict = { # "LangFeedback_Eye_PPO_no": "blue", # "LangFeedback_Ask_PPO_no": "orange", # "LangFeedback_NO_PPO_no": "green", # "LangFeedback_AskEye_PPO_no": "magenta", # } # # def label_parser(label, figure_id, label_parser_dict=None): # label = label.split("_env_")[0].split("SAI_")[1] # label=label.replace("_CS_", "_") # label=label.replace("_F_", "_") # # # label=label.replace("LangFeedback_PPO", "LangFeedback_EYE_PPO") # label=label.replace("PPO_NO", "PPO_no") # # label=label.replace("EYE", "Eye") # label=label.replace("No", "No") # label=label.replace("ASK", "Ask") # label=label.replace("Ask_Eye", "AskEye") # return label # # eval_filename = f"/home/flowers/Documents/projects/embodied_acting_and_speaking/case_studies_figures/Formats_no_eval.png" # # require_patterns = [ # "11-07_bAI_cb_GS_param_tanh_env_SocialAI-SocialAIParamEnv-v1_exploration-bonus-type_cell_exploration-bonus-params__2_50_exploration-bonus-tanh_0.6", # # "04-11_SAI_ImitationDistr_CS_PPO_CB_small_env_SocialAI-EEmulationDistrInformationSeekingParamEnv-v1_recurrence_10", # # "04-11_SAI_ImitationDistr_CS_PPO_CB_small_env_SocialAI-EEmulationDistrInformationSeekingParamEnv-v1_recurrence_10", # "03-11_SAI_ImitationDistr_CS_PPO_CB_env_SocialAI-EEmulationDistrInformationSeekingParamEnv-v1_recurrence_10", # # "04-11_SAI_ImitationNoDistr_CS_PPO_CB_small_env_SocialAI-EEmulationNoDistrInformationSeekingParamEnv-v1_recurrence_10", # ] # require_patterns = [ # "02-11_SAI_LangColor_CS_3C_PPO_CBL", # "02-11_SAI_JA_LangColor_CS_3C_PPO_CBL", # ] # at least one of those # all of those include_patterns = [ "_" ] #include_patterns = ["rec_5"] if eval_filename: # saving fontsize = 40 legend_fontsize = 30 linewidth = 10 else: fontsize = 5 legend_fontsize = 5 linewidth = 1 fontsize = 5 legend_fontsize = 5 linewidth = 1 title_fontsize = int(fontsize*1.2) storage_dir = "storage/" if load_pattern.startswith(storage_dir): load_pattern = load_pattern[len(storage_dir):] if load_pattern.startswith("./storage/"): load_pattern = load_pattern[len("./storage/"):] get_datasets(storage_dir, str(load_pattern), load_subsample_step=load_subsample_step, ignore_patterns=ignore_patterns, require_patterns=require_patterns) label_parser_dict = { # "PPO_CB": "PPO_CB", # "02-06_AppleStealing_experiments_cb_bonus_angle_occ_env_SocialAI-OthersPerceptionInferenceParamEnv-v1_exploration-bonus-type_cell": "NPC_visible", } env_type = str(load_pattern) fig_type = "test" try: top_n = int(sys.argv[2]) except: top_n = 8 to_remove = [] for tr_ in to_remove: if tr_ in models_saves: del models_saves[tr_] print("Loaded:") print("\n".join(list(models_saves.keys()))) #### get_datasets("storage/", "RERUN_WizardGuide_lang64_nameless") #### get_datasets("storage/", "RERUN_WizardTwoGuides_lang64_nameless") if per_model_colors: # order runs for legend order as in per_models_colors, with corresponding colors ordered_labels = OrderedDict() for teacher_type in per_model_colors.keys(): for k,v in labels.items(): if teacher_type in k: ordered_labels[k] = v labels = ordered_labels else: print('not using per_model_color') for k in models_saves.keys(): labels[k] = k def plot_with_shade_seed(subplot_nb, ax, x, y, err, color, shade_color, label, y_min=None, y_max=None, legend=False, leg_size=30, leg_loc='best', title=None, ylim=[0,100], xlim=[0,40], leg_args={}, leg_linewidth=13.0, linewidth=10.0, labelsize=20, filename=None, zorder=None, xlabel='perf', ylabel='Env steps'): plt.rcParams.update({'font.size': 15}) plt.rcParams['axes.xmargin'] = 0 plt.rcParams['axes.ymargin'] = 0 ax.locator_params(axis='x', nbins=3) ax.locator_params(axis='y', nbins=3) ax.tick_params(axis='both', which='major', labelsize=labelsize) x = x[:len(y)] # ax.scatter(x, y, color=color, linewidth=linewidth, zorder=zorder) ax.plot(x, y, color=color, label=label, linewidth=linewidth, zorder=zorder) if err is not None: ax.fill_between(x, y-err, y+err, color=shade_color, alpha=0.2) if legend: leg = ax.legend(loc=leg_loc, **leg_args) #34 for legobj in leg.legendHandles: legobj.set_linewidth(leg_linewidth) ax.set_xlabel(xlabel, fontsize=fontsize) if subplot_nb == 0: ax.set_ylabel(ylabel, fontsize=fontsize, labelpad=4) ax.set_xlim(xmin=xlim[0],xmax=xlim[1]) ax.set_ylim(bottom=ylim[0],top=ylim[1]) if title: ax.set_title(title, fontsize=fontsize) # if filename is not None: # f.savefig(filename) # Plot utils def plot_with_shade_grg(subplot_nb, ax, x, y, err, color, shade_color, label, legend=False, leg_loc='best', title=None, ylim=[0, 100], xlim=[0, 40], leg_args={}, leg_linewidth=13.0, linewidth=10.0, labelsize=20, fontsize=20, title_fontsize=30, zorder=None, xlabel='Perf', ylabel='Env steps', linestyle="-", xnbins=3, ynbins=3, filename=None): #plt.rcParams.update({'font.size': 15}) ax.locator_params(axis='x', nbins=xnbins) ax.locator_params(axis='y', nbins=ynbins) ax.tick_params(axis='y', which='both', labelsize=labelsize) ax.tick_params(axis='x', which='both', labelsize=labelsize*0.8) # ax.tick_params(axis='both', which='both', labelsize="small") # ax.scatter(x, y, color=color,linewidth=linewidth,zorder=zorder, linestyle=linestyle) ax.plot(x, y, color=color, label=label, linewidth=linewidth, zorder=zorder, linestyle=linestyle) ax.fill_between(x, y-err, y+err, color=shade_color, alpha=0.2) if legend: leg = ax.legend(loc=leg_loc, **leg_args) # 34 for legobj in leg.legendHandles: legobj.set_linewidth(leg_linewidth) ax.set_xlabel(xlabel, fontsize=fontsize) if subplot_nb == 0: ax.set_ylabel(ylabel, fontsize=fontsize, labelpad=2) ax.set_xlim(xmin=xlim[0], xmax=xlim[1]) ax.set_ylim(bottom=ylim[0], top=ylim[1]) if title: ax.set_title(title, fontsize=title_fontsize) # if filename is not None: # f.savefig(filename) # Metric plot # metric = 'success_rate_mean' # metric = 'mission_string_observed_mean' # metric = 'extrinsic_return_mean' # metric = 'extrinsic_return_max' # metric = "rreturn_mean" # metric = 'rreturn_max' # metric = 'FPS' # metric = 'duration' # metric = 'intrinsic_reward_perf2_' # metric = 'NPC_intro' metrics = [ 'success_rate_mean', # 'FPS', # 'extrinsic_return_mean', # 'exploration_bonus_mean', 'NPC_intro', # 'curriculum_param_mean', # 'curriculum_max_success_rate_mean', # 'rreturn_mean' ] # f, ax = plt.subplots(1, len(metrics), figsize=(15.0, 9.0)) f, ax = plt.subplots(1, len(metrics), figsize=(9.0, 9.0)) # f, ax = plt.subplots(1, len(metrics), figsize=(20.0, 20.0)) # f, ax = plt.subplots(1, 1, figsize=(5.0, 3.0)) if len(metrics) == 1: ax = [ax] max_y = -np.inf min_y = np.inf # hardcoded min_y, max_y = 0.0, 1.0 max_steps = 0 exclude_patterns = [] # def label_parser(label, figure_id, label_parser_dict=None): # # label = label.split("_env_")[0].split("SAI_")[1] # # # # Pointing # # label=label.replace("Pointing_CS_", "") # # # Feedback # label=label.replace("LangFeedback_CS_", "") # # # # label=label.replace("CS_PPO", "7COL_PPO") # # label=label.replace("CS_3C_PPO", "3COL_PPO") # # label=label.replace("CS_5C_PPO", "5COL_PPO") # # # label=label.replace("CS_PPO", "Eye_contact_PPO") # # label=label.replace("CS_F_ASK_PPO", "Ask_PPO") # # label=label.replace("CS_F_NO_PPO", "NO_PPO") # # label=label.replace("CS_F_ASK_EYE_PPO", "Ask_Eye_contact_PPO") # # # # label=label.replace("PPO_no", "PPO_no_bonus") # # label=label.replace("PPO_NO", "PPO_no_bonus") # # if label_parser_dict: # if sum([1 for k, v in label_parser_dict.items() if k in label]) != 1: # if label in label_parser_dict: # # see if there is an exact match # return label_parser_dict[label] # else: # print("ERROR multiple curves match a lable and there is no exact match for {}".format(label)) # exit() # # for k, v in label_parser_dict.items(): # if k in label: return v # # else: # # return label.split("_env_")[1] # if figure_id not in [1, 2, 3, 4]: # return label # else: # # default # pass # # return label for metric_i, metric in enumerate(metrics): min_y, max_y = 0.0, 1.0 default_colors = default_colors_.copy() for model_i, m_id in enumerate(models_saves.keys()): #excluding some experiments if any([ex_pat in m_id for ex_pat in exclude_patterns]): continue if len(include_patterns) > 0: if not any([in_pat in m_id for in_pat in include_patterns]): continue runs_data = models_saves[m_id]['data'] ys = [] if runs_data[0]['frames'][1] == 'frames': runs_data[0]['frames'] = list(filter(('frames').__ne__, runs_data[0]['frames'])) ########################################### if per_seed: min_len = None else: # determine minimal run length across seeds lens = [len(run['frames']) for run in runs_data if len(run['frames'])] minimum = sorted(lens)[-min(top_n, len(lens))] min_len = np.min([len(run['frames']) for run in runs_data if len(run['frames']) >= minimum]) # keep only top k runs_data = [run for run in runs_data if len(run['frames']) >= minimum] # min_len = np.min([len(run['frames']) for run in runs_data if len(run['frames']) > 10]) # compute env steps (x axis) longest_id = np.argmax([len(rd['frames']) for rd in runs_data]) steps = np.array(runs_data[longest_id]['frames'], dtype=np.int) / steps_denom steps = steps[:min_len] for run in runs_data: if metric not in run: # succes_rate_mean <==> bin_extrinsic_return_mean if metric == 'success_rate_mean': metric_ = "bin_extrinsic_return_mean" if metric_ not in run: raise ValueError("Neither {} or {} is present: {} Possible metrics: {}. ".format(metric, metric_, list(run.keys()))) data = run[metric_] else: raise ValueError("Unknown metric: {} Possible metrics: {}. ".format(metric, list(run.keys()))) else: data = run[metric] if data[1] == metric: data = np.array(list(filter((metric).__ne__, data)), dtype=np.float16) ########################################### if per_seed: ys.append(data) else: if len(data) >= min_len: if len(data) > min_len: print("run has too many {} datapoints ({}). Discarding {}".format(m_id, len(data), len(data)-min_len)) data = data[0:min_len] ys.append(data) else: raise ValueError("How can data be < min_len if it was capped above") ys_same_len = ys # computes stats n_seeds = len(ys_same_len) if per_seed: sems = np.array(ys_same_len) stds = np.array(ys_same_len) means = np.array(ys_same_len) color = default_colors[model_i] else: sems = np.std(ys_same_len, axis=0)/np.sqrt(len(ys_same_len)) # sem stds = np.std(ys_same_len, axis=0) # std means = np.mean(ys_same_len, axis=0) color = default_colors[model_i] # per-metric adjustments ylabel = metric ylabel = { "success_rate_mean" : "Success rate", "exploration_bonus_mean": "Exploration bonus", "NPC_intro": "Successful introduction (%)", }.get(ylabel, ylabel) if metric == 'duration': ylabel = "time (hours)" means = means / 3600 sems = sems / 3600 stds = stds / 3600 if per_seed: #plot x y bounds curr_max_y = np.max(np.max(means)) curr_min_y = np.min(np.min(means)) curr_max_steps = np.max(np.max(steps)) else: # plot x y bounds curr_max_y = np.max(means+stds) curr_min_y = np.min(means-stds) curr_max_steps = np.max(steps) if curr_max_y > max_y: max_y = curr_max_y if curr_min_y < min_y: min_y = curr_min_y if curr_max_steps > max_steps: max_steps = curr_max_steps if subsample_step: steps = steps[0::subsample_step] means = means[0::subsample_step] stds = stds[0::subsample_step] sems = sems[0::subsample_step] ys_same_len = [y[0::subsample_step] for y in ys_same_len] # display seeds separtely if per_seed: for s_i, seed_ys in enumerate(ys_same_len): seed_c = default_colors[model_i+s_i] # label = m_id#+"(s:{})".format(s_i) label = str(s_i) seed_ys = smooth(seed_ys, smooth_factor) plot_with_shade_seed(0, ax[metric_i], steps, seed_ys, None, seed_c, seed_c, label, legend=draw_legend, xlim=[0, max_steps], ylim=[min_y, max_y], leg_size=leg_size, xlabel=f"Env steps (1e6)", ylabel=ylabel, linewidth=linewidth, labelsize=fontsize, # fontsize=fontsize, ) summary_dict[s_i] = seed_ys[-1] summary_dict_colors[s_i] = seed_c else: label = label_parser(m_id, load_pattern, label_parser_dict=label_parser_dict) if color_dict: color = color_dict[label] else: color = default_colors[model_i] label = label+"({})".format(n_seeds) if smooth_factor: means = smooth(means, smooth_factor) stds = smooth(stds, smooth_factor) x_lim = max(steps[-1], x_lim) x_lim = min(max_x_lim, x_lim) leg_args = { 'fontsize': legend_fontsize } plot_with_shade_grg( 0, ax[metric_i], steps, means, stds, color, color, label, legend=draw_legend and metric_i == 0, xlim=[0, x_lim], ylim=[0, max_y], xlabel=f"Env steps (1e6)", ylabel=ylabel, title=None, labelsize=fontsize*train_inc_font, fontsize=fontsize*train_inc_font, title_fontsize=title_fontsize, linewidth=linewidth, leg_linewidth=5, leg_args=leg_args, xnbins=xnbins, ynbins=ynbins, ) summary_dict[label] = means[-1] summary_dict_colors[label] = color if len(summary_dict) == 0: raise ValueError(f"No experiments found for {load_pattern}.") # print summary best = max(summary_dict.values()) pc = 0.3 n = int(len(summary_dict)*pc) print("top n: ", n) top_pc = sorted(summary_dict.values())[-n:] bottom_pc = sorted(summary_dict.values())[:n] print("legend:") cprint("\tbest", "green") cprint("\ttop {} %".format(pc), "blue") cprint("\tbottom {} %".format(pc), "red") print("\tothers") print() for l, p in sorted(summary_dict.items(), key=lambda kv: kv[1]): c = summary_dict_colors[l] if p == best: cprint("label: {} ({})".format(l, c), "green") cprint("\t {}:{}".format(metric, p), "green") elif p in top_pc: cprint("label: {} ({})".format(l, c), "blue") cprint("\t {}:{}".format(metric, p), "blue") elif p in bottom_pc: cprint("label: {} ({})".format(l, c), "red") cprint("\t {}:{}".format(metric, p), "red") else: print("label: {} ({})".format(l, c)) print("\t {}:{}".format(metric, p)) for label, (mean, std, color) in static_lines.items(): plot_with_shade_grg( 0, ax[metric_i], steps, np.array([mean]*len(steps)), np.array([std]*len(steps)), color, color, label, legend=True, xlim=[0, x_lim], ylim=[0, 1.0], xlabel=f"Env steps (1e6)", ylabel=ylabel, linestyle=":", leg_args=leg_args, fontsize=fontsize, title_fontsize=title_fontsize, xnbins=xnbins, ynbins=ynbins, ) # plt.tight_layout() # f.savefig('graphics/{}_{}_results.svg'.format(str(figure_id, metric))) # f.savefig('graphics/{}_{}_results.png'.format(str(figure_id, metric))) cprint("Ignore pattern: {}".format(ignore_patterns), "blue") if plot_train: plt.tight_layout() # plt.subplots_adjust(hspace=1.5, wspace=0.5, left=0.1, right=0.9, bottom=0.1, top=0.85) plt.subplots_adjust(hspace=1.5, wspace=0.5, left=0.1, right=0.9, bottom=0.1, top=0.85) plt.suptitle(super_title) plt.show() plt.close() curr_max_y = 0 x_lim = 0 max_y = -np.inf min_y = np.inf # hardcoded min_y, max_y = 0.0, 1.0 grid = True draw_eval_legend = True if study_eval: print("Evaluation") # evaluation sets number_of_eval_envs = max(list([len(v.keys()) for v in model_eval_data.values()])) if plot_aggregated_test: number_of_eval_envs += 1 if number_of_eval_envs == 0: print("No eval envs") exit() if plot_only_aggregated_test: f, ax = plt.subplots(1, 1, figsize=(9.0, 9.0)) else: if grid: # grid subplot_y = math.ceil(math.sqrt(number_of_eval_envs)) subplot_x = math.ceil(number_of_eval_envs / subplot_y) # from IPython import embed; embed() while subplot_x % 1 != 0: subplot_y -= 1 subplot_x = number_of_eval_envs / subplot_y if subplot_x == 1: subplot_y = math.ceil(math.sqrt(number_of_eval_envs)) subplot_x = math.floor(math.sqrt(number_of_eval_envs)) subplot_y = int(subplot_y) subplot_x = int(subplot_x) assert subplot_y * subplot_x >= number_of_eval_envs f, ax_ = plt.subplots(subplot_y, subplot_x, figsize=(6.0, 6.0), sharey=False) #, sharex=True, sharey=True) if subplot_y != 1: ax = list(chain.from_iterable(ax_)) else: ax=ax_ else: # flat f, ax = plt.subplots(1, number_of_eval_envs, figsize=(15.0, 9.0)) #), sharey=True, sharex=True) if number_of_eval_envs == 1: ax = [ax] default_colors = default_colors_.copy() test_summary_dict = defaultdict(dict) test_summary_dict_colors = defaultdict(dict) for model_i, m_id in enumerate(model_eval_data.keys()): # excluding some experiments if any([ex_pat in m_id for ex_pat in exclude_patterns]): continue if len(include_patterns) > 0: if not any([in_pat in m_id for in_pat in include_patterns]): continue # computes stats if sort_test: test_envs_sorted = enumerate(sorted(model_eval_data[m_id].items(), key=lambda kv: sort_test_set(kv[0]))) else: test_envs_sorted = enumerate(model_eval_data[m_id].items()) if plot_aggregated_test: agg_means = [] for env_i, (test_env, env_data) in test_envs_sorted: ys_same_len = env_data["values"] steps = env_data["steps"].mean(0) / steps_denom n_seeds = len(ys_same_len) if per_seed: sems = np.array(ys_same_len) stds = np.array(ys_same_len) means = np.array(ys_same_len) color = default_colors[model_i] else: sems = np.std(ys_same_len, axis=0) / np.sqrt(len(ys_same_len)) # sem stds = np.std(ys_same_len, axis=0) # std means = np.mean(ys_same_len, axis=0) color = default_colors[model_i] # per-metric adjusments if per_seed: # plot x y bounds curr_max_y = np.max(np.max(means)) curr_min_y = np.min(np.min(means)) curr_max_steps = np.max(np.max(steps)) else: # plot x y bounds curr_max_y = np.max(means + stds) curr_min_y = np.min(means - stds) curr_max_steps = np.max(steps) if plot_aggregated_test: agg_means.append(means) if curr_max_y > max_y: max_y = curr_max_y if curr_min_y < min_y: min_y = curr_min_y x_lim = max(steps[-1], x_lim) x_lim = min(max_x_lim, x_lim) eval_metric_name = { "test_success_rates": "Success rate", 'exploration_bonus_mean': "Exploration bonus", }.get(eval_metric, eval_metric) test_env_name = test_env.replace("Env", "").replace("Test", "") env_types = ["InformationSeeking", "Collaboration", "PerspectiveTaking"] for env_type in env_types: if env_type in test_env_name: test_env_name = test_env_name.replace(env_type, "") test_env_name += f"\n({env_type})" if grid: ylabel = eval_metric_name title = test_env_name else: # flat ylabel = test_env_name title = eval_metric_name leg_args = { 'fontsize': legend_fontsize // 1 } if per_seed: for s_i, seed_ys in enumerate(ys_same_len): seed_c = default_colors[model_i + s_i] # label = m_id#+"(s:{})".format(s_i) label = str(s_i) if not plot_only_aggregated_test: seed_ys = smooth(seed_ys, eval_smooth_factor) plot_with_shade_seed(0, ax[env_i], steps, seed_ys, None, seed_c, seed_c, label, legend=draw_eval_legend, xlim=[0, x_lim], ylim=[min_y, max_y], leg_size=leg_size, xlabel=f"Steps (1e6)", ylabel=ylabel, linewidth=linewidth, title=title) test_summary_dict[s_i][test_env] = seed_ys[-1] test_summary_dict_colors[s_i] = seed_c else: label = label_parser(m_id, load_pattern, label_parser_dict=label_parser_dict) if not plot_only_aggregated_test: if color_dict: color = color_dict[label] else: color = default_colors[model_i] label = label + "({})".format(n_seeds) if smooth_factor: means = smooth(means, eval_smooth_factor) stds = smooth(stds, eval_smooth_factor) plot_with_shade_grg( 0, ax[env_i], steps, means, stds, color, color, label, legend=draw_eval_legend, xlim=[0, x_lim+1], ylim=[0, max_y], xlabel=f"Env steps (1e6)" if env_i // (subplot_x) == subplot_y -1 else None, # only last line ylabel=ylabel if env_i % subplot_x == 0 else None, # only first row title=title, title_fontsize=title_fontsize, labelsize=fontsize, fontsize=fontsize, linewidth=linewidth, leg_linewidth=5, leg_args=leg_args, xnbins=xnbins, ynbins=ynbins, ) test_summary_dict[label][test_env] = means[-1] test_summary_dict_colors[label] = color if plot_aggregated_test: if plot_only_aggregated_test: agg_env_i = 0 else: agg_env_i = number_of_eval_envs - 1 # last one agg_means = np.array(agg_means) agg_mean = agg_means.mean(axis=0) agg_std = agg_means.std(axis=0) # std if smooth_factor and not per_seed: agg_mean = smooth(agg_mean, eval_smooth_factor) agg_std = smooth(agg_std, eval_smooth_factor) if color_dict: color = color_dict[re.sub("\([0-9]\)", '', label)] else: color = default_colors[model_i] if per_seed: print("Not smooth aggregated because of per seed") for s_i, (seed_ys, seed_st) in enumerate(zip(agg_mean, agg_std)): seed_c = default_colors[model_i + s_i] # label = m_id#+"(s:{})".format(s_i) label = str(s_i) # seed_ys = smooth(seed_ys, eval_smooth_factor) plot_with_shade_seed(0, ax if plot_only_aggregated_test else ax[agg_env_i], steps, seed_ys, seed_st, seed_c, seed_c, label, legend=draw_eval_legend, xlim=[0, x_lim], ylim=[min_y, max_y], labelsize=fontsize, filename=eval_filename, leg_size=leg_size, xlabel=f"Steps (1e6)", ylabel=ylabel, linewidth=1, title=agg_title) else: # just used for creating a dummy Imitation test figure -> delete # agg_mean = agg_mean * 0.1 # agg_std = agg_std * 0.1 # max_y = 1 plot_with_shade_grg( 0, ax if plot_only_aggregated_test else ax[agg_env_i], steps, agg_mean, agg_std, color, color, label, legend=draw_eval_legend, xlim=[0, x_lim + 1], ylim=[0, max_y], xlabel=f"Steps (1e6)" if plot_only_aggregated_test or (agg_env_i // (subplot_x) == subplot_y - 1) else None, # only last line ylabel=ylabel if plot_only_aggregated_test or (agg_env_i % subplot_x == 0) else None, # only first row title_fontsize=title_fontsize, title=agg_title, labelsize=fontsize, fontsize=fontsize, linewidth=linewidth, leg_linewidth=5, leg_args=leg_args, xnbins=xnbins, ynbins=ynbins, filename=eval_filename, ) # print summary means_dict = { lab: np.array(list(lab_sd.values())).mean() for lab, lab_sd in test_summary_dict.items() } best = max(means_dict.values()) pc = 0.3 n = int(len(means_dict) * pc) print("top n: ", n) top_pc = sorted(means_dict.values())[-n:] bottom_pc = sorted(means_dict.values())[:n] print("Legend:") cprint("\tbest", "green") cprint("\ttop {} %".format(pc), "blue") cprint("\tbottom {} %".format(pc), "red") print("\tothers") print() for l, l_mean in sorted(means_dict.items(), key=lambda kv: kv[1]): l_summary_dict = test_summary_dict[l] c = test_summary_dict_colors[l] print("label: {} ({})".format(l, c)) #print("\t{}({}) - Mean".format(l_mean, metric)) if l_mean == best: cprint("\t{}({}) - Mean".format(l_mean, eval_metric), "green") elif l_mean in top_pc: cprint("\t{}({}) - Mean".format(l_mean, eval_metric), "blue") elif l_mean in bottom_pc: cprint("\t{}({}) - Mean".format(l_mean, eval_metric), "red") else: print("\t{}({})".format(l_mean, eval_metric)) n_over_50 = 0 if sort_test: sorted_envs = sorted(l_summary_dict.items(), key=lambda kv: sort_test_set(env_name=kv[0])) else: sorted_envs = l_summary_dict.items() for tenv, p in sorted_envs: if p < 0.5: print("\t{:4f}({}) - \t{}".format(p, eval_metric, tenv)) else: print("\t{:4f}({}) -*\t{}".format(p, eval_metric, tenv)) n_over_50 += 1 print("\tenv over 50 - {}/{}".format(n_over_50, len(l_summary_dict))) if plot_test: plt.tight_layout() # plt.subplots_adjust(hspace=0.8, wspace=0.15, left=0.035, right=0.99, bottom=0.065, top=0.93) plt.show() if eval_filename is not None: plt.subplots_adjust(hspace=0.8, wspace=0.15, left=0.15, right=0.99, bottom=0.15, top=0.93) res= input(f"Save to {eval_filename} (y/n)?") if res == "y": f.savefig(eval_filename) print(f'saved to {eval_filename}') else: print('not saved')