""" This script computes the median difference and confidence intervals of all techniques from the ablation study for improving the masker evaluation metrics. The differences in the metrics are computed for all images of paired models, that is those which only differ in the inclusion or not of the given technique. Then, statistical inference is performed through the percentile bootstrap to obtain robust estimates of the differences in the metrics and confidence intervals. The script plots the summary for all techniques. """ print("Imports...", end="") from argparse import ArgumentParser import yaml import numpy as np import pandas as pd import seaborn as sns from scipy.special import comb from scipy.stats import trim_mean from tqdm import tqdm from collections import OrderedDict from pathlib import Path import matplotlib.pyplot as plt import matplotlib.patches as mpatches import matplotlib.transforms as transforms # ----------------------- # ----- Constants ----- # ----------------------- dict_metrics = { "names": { "tpr": "TPR, Recall, Sensitivity", "tnr": "TNR, Specificity, Selectivity", "fpr": "FPR", "fpt": "False positives relative to image size", "fnr": "FNR, Miss rate", "fnt": "False negatives relative to image size", "mpr": "May positive rate (MPR)", "mnr": "May negative rate (MNR)", "accuracy": "Accuracy (ignoring may)", "error": "Error", "f05": "F05 score", "precision": "Precision", "edge_coherence": "Edge coherence", "accuracy_must_may": "Accuracy (ignoring cannot)", }, "key_metrics": ["error", "f05", "edge_coherence"], } dict_techniques = OrderedDict( [ ("pseudo", "Pseudo labels"), ("depth", "Depth (D)"), ("seg", "Seg. (S)"), ("spade", "SPADE"), ("dada_seg", "DADA (S)"), ("dada_masker", "DADA (M)"), ] ) # Model features model_feats = [ "masker", "seg", "depth", "dada_seg", "dada_masker", "spade", "pseudo", "ground", "instagan", ] # Colors crest = sns.color_palette("crest", as_cmap=False, n_colors=7) palette_metrics = [crest[0], crest[3], crest[6]] sns.palplot(palette_metrics) # Markers dict_markers = {"error": "o", "f05": "s", "edge_coherence": "^"} def parsed_args(): """ Parse and returns command-line args Returns: argparse.Namespace: the parsed arguments """ parser = ArgumentParser() parser.add_argument( "--input_csv", default="ablations_metrics_20210311.csv", type=str, help="CSV containing the results of the ablation study", ) parser.add_argument( "--output_dir", default=None, type=str, help="Output directory", ) parser.add_argument( "--dpi", default=200, type=int, help="DPI for the output images", ) parser.add_argument( "--n_bs", default=1e6, type=int, help="Number of bootrstrap samples", ) parser.add_argument( "--alpha", default=0.99, type=float, help="Confidence level", ) parser.add_argument( "--bs_seed", default=17, type=int, help="Bootstrap random seed, for reproducibility", ) return parser.parse_args() def trim_mean_wrapper(a): return trim_mean(a, proportiontocut=0.2) def find_model_pairs(technique, model_feats): model_pairs = [] for mi in df.loc[df[technique]].model_feats.unique(): for mj in df.model_feats.unique(): if mj == mi: continue if df.loc[df.model_feats == mj, technique].unique()[0]: continue is_pair = True for f in model_feats: if f == technique: continue elif ( df.loc[df.model_feats == mj, f].unique()[0] != df.loc[df.model_feats == mi, f].unique()[0] ): is_pair = False break else: pass if is_pair: model_pairs.append((mi, mj)) break return model_pairs if __name__ == "__main__": # ----------------------------- # ----- Parse arguments ----- # ----------------------------- args = parsed_args() print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()])) # Determine output dir if args.output_dir is None: output_dir = Path(os.environ["SLURM_TMPDIR"]) else: output_dir = Path(args.output_dir) if not output_dir.exists(): output_dir.mkdir(parents=True, exist_ok=False) # Store args output_yml = output_dir / "bootstrap_summary.yml" with open(output_yml, "w") as f: yaml.dump(vars(args), f) # Read CSV df = pd.read_csv(args.input_csv, index_col="model_img_idx") # Build data set dfbs = pd.DataFrame(columns=["diff", "technique", "metric"]) for technique in model_feats: # Get pairs model_pairs = find_model_pairs(technique, model_feats) # Compute differences for m_with, m_without in model_pairs: df_with = df.loc[df.model_feats == m_with] df_without = df.loc[df.model_feats == m_without] for metric in dict_metrics["key_metrics"]: diff = ( df_with.sort_values(by="img_idx")[metric].values - df_without.sort_values(by="img_idx")[metric].values ) dfm = pd.DataFrame.from_dict( {"metric": metric, "technique": technique, "diff": diff} ) dfbs = dfbs.append(dfm, ignore_index=True) ### Plot # Set up plot sns.reset_orig() sns.set(style="whitegrid") plt.rcParams.update({"font.family": "serif"}) plt.rcParams.update( { "font.serif": [ "Computer Modern Roman", "Times New Roman", "Utopia", "New Century Schoolbook", "Century Schoolbook L", "ITC Bookman", "Bookman", "Times", "Palatino", "Charter", "serif" "Bitstream Vera Serif", "DejaVu Serif", ] } ) fig, axes = plt.subplots( nrows=1, ncols=3, sharey=True, dpi=args.dpi, figsize=(9, 3) ) metrics = ["error", "f05", "edge_coherence"] dict_ci = {m: {} for m in metrics} for idx, metric in enumerate(dict_metrics["key_metrics"]): ax = sns.pointplot( ax=axes[idx], data=dfbs.loc[dfbs.metric.isin(["error", "f05", "edge_coherence"])], order=dict_techniques.keys(), x="diff", y="technique", hue="metric", hue_order=[metric], markers=dict_markers[metric], palette=[palette_metrics[idx]], errwidth=1.5, scale=0.6, join=False, estimator=trim_mean_wrapper, ci=int(args.alpha * 100), n_boot=args.n_bs, seed=args.bs_seed, ) # Retrieve confidence intervals and update results dictionary for line, technique in zip(ax.lines, dict_techniques.keys()): dict_ci[metric].update( { technique: { "20_trimmed_mean": float( trim_mean_wrapper( dfbs.loc[ (dfbs.technique == technique) & (dfbs.metric == metrics[idx]), "diff", ].values ) ), "ci_left": float(line.get_xdata()[0]), "ci_right": float(line.get_xdata()[1]), } } ) leg_handles, leg_labels = ax.get_legend_handles_labels() # Change spines sns.despine(left=True, bottom=True) # Set Y-label ax.set_ylabel(None) # Y-tick labels ax.set_yticklabels(list(dict_techniques.values()), fontsize="medium") # Set X-label ax.set_xlabel(None) # X-ticks xticks = ax.get_xticks() xticklabels = xticks ax.set_xticks(xticks) ax.set_xticklabels(xticklabels, fontsize="small") # Y-lim display2data = ax.transData.inverted() ax2display = ax.transAxes _, y_bottom = display2data.transform(ax.transAxes.transform((0.0, 0.02))) _, y_top = display2data.transform(ax.transAxes.transform((0.0, 0.98))) ax.set_ylim(bottom=y_bottom, top=y_top) # Draw line at H0 y = np.arange(ax.get_ylim()[1], ax.get_ylim()[0], 0.1) x = 0.0 * np.ones(y.shape[0]) ax.plot(x, y, linestyle=":", linewidth=1.5, color="black") # Draw gray area xlim = ax.get_xlim() ylim = ax.get_ylim() if metric == "error": x0 = xlim[0] width = np.abs(x0) else: x0 = 0.0 width = np.abs(xlim[1]) trans = transforms.blended_transform_factory(ax.transData, ax.transAxes) rect = mpatches.Rectangle( xy=(x0, 0.0), width=width, height=1, transform=trans, linewidth=0.0, edgecolor="none", facecolor="gray", alpha=0.05, ) ax.add_patch(rect) # Legend leg_handles, leg_labels = ax.get_legend_handles_labels() leg_labels = [dict_metrics["names"][metric] for metric in leg_labels] leg = ax.legend( handles=leg_handles, labels=leg_labels, loc="center", title="", bbox_to_anchor=(-0.2, 1.05, 1.0, 0.0), framealpha=1.0, frameon=False, handletextpad=-0.2, ) # Set X-label (title) │ fig.suptitle( "20 % trimmed mean difference and bootstrapped confidence intervals", y=0.0, fontsize="medium", ) # Save figure output_fig = output_dir / "bootstrap_summary.png" fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight") # Store results output_results = output_dir / "bootstrap_summary_results.yml" with open(output_results, "w") as f: yaml.dump(dict_ci, f)