""" This script evaluates the contribution of a technique from the ablation study for improving the masker evaluation metrics. The differences in the metrics are computed for all images of paired models, that is those which only differ in the inclusion or not of the given technique. Then, statistical inference is performed through the percentile bootstrap to obtain robust estimates of the differences in the metrics and confidence intervals. The script plots the distribution of the bootrstraped estimates. """ print("Imports...", end="") from argparse import ArgumentParser import yaml import os import numpy as np import pandas as pd import seaborn as sns from scipy.stats import trim_mean from tqdm import tqdm from pathlib import Path import matplotlib.pyplot as plt import matplotlib.patches as mpatches # ----------------------- # ----- Constants ----- # ----------------------- dict_metrics = { "names": { "tpr": "TPR, Recall, Sensitivity", "tnr": "TNR, Specificity, Selectivity", "fpr": "FPR", "fpt": "False positives relative to image size", "fnr": "FNR, Miss rate", "fnt": "False negatives relative to image size", "mpr": "May positive rate (MPR)", "mnr": "May negative rate (MNR)", "accuracy": "Accuracy (ignoring may)", "error": "Error", "f05": "F05 score", "precision": "Precision", "edge_coherence": "Edge coherence", "accuracy_must_may": "Accuracy (ignoring cannot)", }, "key_metrics": ["f05", "error", "edge_coherence"], } dict_techniques = { "depth": "depth", "segmentation": "seg", "seg": "seg", "dada_s": "dada_seg", "dada_seg": "dada_seg", "dada_segmentation": "dada_seg", "dada_m": "dada_masker", "dada_masker": "dada_masker", "spade": "spade", "pseudo": "pseudo", "pseudo-labels": "pseudo", "pseudo_labels": "pseudo", } # Model features model_feats = [ "masker", "seg", "depth", "dada_seg", "dada_masker", "spade", "pseudo", "ground", "instagan", ] # Colors palette_colorblind = sns.color_palette("colorblind") color_cat1 = palette_colorblind[0] color_cat2 = palette_colorblind[1] palette_lightest = [ sns.light_palette(color_cat1, n_colors=20)[3], sns.light_palette(color_cat2, n_colors=20)[3], ] palette_light = [ sns.light_palette(color_cat1, n_colors=3)[1], sns.light_palette(color_cat2, n_colors=3)[1], ] palette_medium = [color_cat1, color_cat2] palette_dark = [ sns.dark_palette(color_cat1, n_colors=3)[1], sns.dark_palette(color_cat2, n_colors=3)[1], ] palette_cat1 = [ palette_lightest[0], palette_light[0], palette_medium[0], palette_dark[0], ] palette_cat2 = [ palette_lightest[1], palette_light[1], palette_medium[1], palette_dark[1], ] color_cat1_light = palette_light[0] color_cat2_light = palette_light[1] def parsed_args(): """ Parse and returns command-line args Returns: argparse.Namespace: the parsed arguments """ parser = ArgumentParser() parser.add_argument( "--input_csv", default="ablations_metrics_20210311.csv", type=str, help="CSV containing the results of the ablation study", ) parser.add_argument( "--output_dir", default=None, type=str, help="Output directory", ) parser.add_argument( "--technique", default=None, type=str, help="Keyword specifying the technique. One of: pseudo, depth, segmentation, dada_seg, dada_masker, spade", ) parser.add_argument( "--dpi", default=200, type=int, help="DPI for the output images", ) parser.add_argument( "--n_bs", default=1e6, type=int, help="Number of bootrstrap samples", ) parser.add_argument( "--alpha", default=0.99, type=float, help="Confidence level", ) parser.add_argument( "--bs_seed", default=17, type=int, help="Bootstrap random seed, for reproducibility", ) return parser.parse_args() def add_ci_mean( ax, sample_measure, bs_mean, bs_std, ci, color, alpha, fontsize, invert=False ): # Fill area between CI dist = ax.lines[0] dist_y = dist.get_ydata() dist_x = dist.get_xdata() linewidth = dist.get_linewidth() x_idx_low = np.argmin(np.abs(dist_x - ci[0])) x_idx_high = np.argmin(np.abs(dist_x - ci[1])) x_ci = dist_x[x_idx_low:x_idx_high] y_ci = dist_y[x_idx_low:x_idx_high] ax.fill_between(x_ci, 0, y_ci, facecolor=color, alpha=alpha) # Add vertical lines of CI ax.vlines( x=ci[0], ymin=0.0, ymax=y_ci[0], color=color, linewidth=linewidth, label="ci_low", ) ax.vlines( x=ci[1], ymin=0.0, ymax=y_ci[-1], color=color, linewidth=linewidth, label="ci_high", ) # Add annotations bbox_props = dict(boxstyle="round, pad=0.4", fc="w", ec="k", lw=2) if invert: ha_l = "right" ha_u = "left" else: ha_l = "left" ha_u = "right" ax.text( ci[0], 0.0, s="L = {:.4f}".format(ci[0]), ha=ha_l, va="bottom", fontsize=fontsize, bbox=bbox_props, ) ax.text( ci[1], 0.0, s="U = {:.4f}".format(ci[1]), ha=ha_u, va="bottom", fontsize=fontsize, bbox=bbox_props, ) # Add vertical line of bootstrap mean x_idx_mean = np.argmin(np.abs(dist_x - bs_mean)) ax.vlines( x=bs_mean, ymin=0.0, ymax=dist_y[x_idx_mean], color="k", linewidth=linewidth ) # Add annotation of bootstrap mean bbox_props = dict(boxstyle="round, pad=0.4", fc="w", ec="k", lw=2) ax.text( bs_mean, 0.6 * dist_y[x_idx_mean], s="Bootstrap mean = {:.4f}".format(bs_mean), ha="center", va="center", fontsize=fontsize, bbox=bbox_props, ) # Add vertical line of sample_measure x_idx_smeas = np.argmin(np.abs(dist_x - sample_measure)) ax.vlines( x=sample_measure, ymin=0.0, ymax=dist_y[x_idx_smeas], color="k", linewidth=linewidth, linestyles="dotted", ) # Add SD bbox_props = dict(boxstyle="darrow, pad=0.4", fc="w", ec="k", lw=2) ax.text( bs_mean, 0.4 * dist_y[x_idx_mean], s="SD = {:.4f} = SE".format(bs_std), ha="center", va="center", fontsize=fontsize, bbox=bbox_props, ) def add_null_pval(ax, null, color, alpha, fontsize): # Fill area between CI dist = ax.lines[0] dist_y = dist.get_ydata() dist_x = dist.get_xdata() linewidth = dist.get_linewidth() x_idx_null = np.argmin(np.abs(dist_x - null)) if x_idx_null >= (len(dist_x) / 2.0): x_pval = dist_x[x_idx_null:] y_pval = dist_y[x_idx_null:] else: x_pval = dist_x[:x_idx_null] y_pval = dist_y[:x_idx_null] ax.fill_between(x_pval, 0, y_pval, facecolor=color, alpha=alpha) # Add vertical lines of null dist = ax.lines[0] linewidth = dist.get_linewidth() y_max = ax.get_ylim()[1] ax.vlines( x=null, ymin=0.0, ymax=y_max, color="k", linewidth=linewidth, linestyles="dotted", ) # Add annotations bbox_props = dict(boxstyle="round, pad=0.4", fc="w", ec="k", lw=2) ax.text( null, 0.75 * y_max, s="Null hypothesis = {:.1f}".format(null), ha="center", va="center", fontsize=fontsize, bbox=bbox_props, ) def plot_bootstrap_distr( sample_measure, bs_samples, alpha, color_ci, color_pval=None, null=None ): # Compute results from bootstrap q_low = (1.0 - alpha) / 2.0 q_high = 1.0 - q_low ci = np.quantile(bs_samples, [q_low, q_high]) bs_mean = np.mean(bs_samples) bs_std = np.std(bs_samples) if null is not None and color_pval is not None: pval_flag = True pval = np.min([[np.mean(bs_samples > null), np.mean(bs_samples < null)]]) * 2 else: pval_flag = False # Set up plot sns.set(style="whitegrid") fontsize = 24 font = {"family": "DejaVu Sans", "weight": "normal", "size": fontsize} plt.rc("font", **font) alpha_plot = 0.5 # Initialize the matplotlib figure fig, ax = plt.subplots(figsize=(30, 12), dpi=args.dpi) # Plot distribution of bootstrap means sns.kdeplot(bs_samples, color="b", linewidth=5, gridsize=1000, ax=ax) y_lim = ax.get_ylim() # Change spines sns.despine(left=True, bottom=True) # Annotations add_ci_mean( ax, sample_measure, bs_mean, bs_std, ci, color=color_ci, alpha=alpha_plot, fontsize=fontsize, ) if pval_flag: add_null_pval(ax, null, color=color_pval, alpha=alpha_plot, fontsize=fontsize) # Legend ci_patch = mpatches.Patch( facecolor=color_ci, edgecolor=None, alpha=alpha_plot, label="{:d} % confidence interval".format(int(100 * alpha)), ) if pval_flag: if pval == 0.0: pval_patch = mpatches.Patch( facecolor=color_pval, edgecolor=None, alpha=alpha_plot, label="P value / 2 = {:.1f}".format(pval / 2.0), ) elif np.around(pval / 2.0, decimals=4) > 0.0000: pval_patch = mpatches.Patch( facecolor=color_pval, edgecolor=None, alpha=alpha_plot, label="P value / 2 = {:.4f}".format(pval / 2.0), ) else: pval_patch = mpatches.Patch( facecolor=color_pval, edgecolor=None, alpha=alpha_plot, label="P value / 2 < $10^{}$".format(np.ceil(np.log10(pval / 2.0))), ) leg = ax.legend( handles=[ci_patch, pval_patch], ncol=1, loc="upper right", frameon=True, framealpha=1.0, title="", fontsize=fontsize, columnspacing=1.0, labelspacing=0.2, markerfirst=True, ) else: leg = ax.legend( handles=[ci_patch], ncol=1, loc="upper right", frameon=True, framealpha=1.0, title="", fontsize=fontsize, columnspacing=1.0, labelspacing=0.2, markerfirst=True, ) plt.setp(leg.get_title(), fontsize=fontsize, horizontalalignment="left") # Set X-label ax.set_xlabel("Bootstrap estimates", rotation=0, fontsize=fontsize, labelpad=10.0) # Set Y-label ax.set_ylabel("Density", rotation=90, fontsize=fontsize, labelpad=10.0) # Ticks plt.setp(ax.get_xticklabels(), fontsize=0.8 * fontsize, verticalalignment="top") plt.setp(ax.get_yticklabels(), fontsize=0.8 * fontsize) ax.set_ylim(y_lim) return fig, bs_mean, bs_std, ci, pval if __name__ == "__main__": # ----------------------------- # ----- Parse arguments ----- # ----------------------------- args = parsed_args() print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()])) # Determine output dir if args.output_dir is None: output_dir = Path(os.environ["SLURM_TMPDIR"]) else: output_dir = Path(args.output_dir) if not output_dir.exists(): output_dir.mkdir(parents=True, exist_ok=False) # Store args output_yml = output_dir / "{}_bootstrap.yml".format(args.technique) with open(output_yml, "w") as f: yaml.dump(vars(args), f) # Determine technique if args.technique.lower() not in dict_techniques: raise ValueError("{} is not a valid technique".format(args.technique)) else: technique = dict_techniques[args.technique.lower()] # Read CSV df = pd.read_csv(args.input_csv, index_col="model_img_idx") # Find relevant model pairs model_pairs = [] for mi in df.loc[df[technique]].model_feats.unique(): for mj in df.model_feats.unique(): if mj == mi: continue if df.loc[df.model_feats == mj, technique].unique()[0]: continue is_pair = True for f in model_feats: if f == technique: continue elif ( df.loc[df.model_feats == mj, f].unique()[0] != df.loc[df.model_feats == mi, f].unique()[0] ): is_pair = False break else: pass if is_pair: model_pairs.append((mi, mj)) break print("\nModel pairs identified:\n") for pair in model_pairs: print("{} & {}".format(pair[0], pair[1])) df["base"] = ["N/A"] * len(df) for spp in model_pairs: df.loc[df.model_feats.isin(spp), "depth_base"] = spp[1] # Build bootstrap data data = {m: [] for m in dict_metrics["key_metrics"]} for m_with, m_without in model_pairs: df_with = df.loc[df.model_feats == m_with] df_without = df.loc[df.model_feats == m_without] for metric in data.keys(): diff = ( df_with.sort_values(by="img_idx")[metric].values - df_without.sort_values(by="img_idx")[metric].values ) data[metric].extend(diff.tolist()) # Run bootstrap measures = ["mean", "median", "20_trimmed_mean"] bs_data = {meas: {m: np.zeros(args.n_bs) for m in data.keys()} for meas in measures} np.random.seed(args.bs_seed) for m, data_m in data.items(): for idx, s in enumerate(tqdm(range(args.n_bs))): # Sample with replacement bs_sample = np.random.choice(data_m, size=len(data_m), replace=True) # Store mean bs_data["mean"][m][idx] = np.mean(bs_sample) # Store median bs_data["median"][m][idx] = np.median(bs_sample) # Store 20 % trimmed mean bs_data["20_trimmed_mean"][m][idx] = trim_mean(bs_sample, 0.2) for metric in dict_metrics["key_metrics"]: sample_measure = trim_mean(data[metric], 0.2) fig, bs_mean, bs_std, ci, pval = plot_bootstrap_distr( sample_measure, bs_data["20_trimmed_mean"][metric], alpha=args.alpha, color_ci=color_cat1_light, color_pval=color_cat2_light, null=0.0, ) # Save figure output_fig = output_dir / "{}_bootstrap_{}_{}.png".format( args.technique, metric, "20_trimmed_mean" ) fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight") # Store results output_results = output_dir / "{}_bootstrap_{}_{}.yml".format( args.technique, metric, "20_trimmed_mean" ) results_dict = { "measure": "20_trimmed_mean", "sample_measure": float(sample_measure), "bs_mean": float(bs_mean), "bs_std": float(bs_std), "ci_left": float(ci[0]), "ci_right": float(ci[1]), "pval": float(pval), } with open(output_results, "w") as f: yaml.dump(results_dict, f)