|
""" |
|
This script computes the median difference and confidence intervals of all techniques from the ablation study for |
|
improving the masker evaluation metrics. The differences in the metrics are computed |
|
for all images of paired models, that is, those which differ only in the inclusion or
|
not of the given technique. Then, statistical inference is performed through the |
|
percentile bootstrap to obtain robust estimates of the differences in the metrics and |
|
confidence intervals. The script plots the summary for all techniques. |
|
""" |
|
# Announce the (potentially slow) imports below; `end=""` keeps the cursor on
# the same line. NOTE(review): nothing is ever printed after this ("ok"/"done"),
# so the line stays unterminated — confirm whether that is intentional.
print("Imports...", end="")
|
import os
from argparse import ArgumentParser
from collections import OrderedDict
from pathlib import Path

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import matplotlib.transforms as transforms
import numpy as np
import pandas as pd
import seaborn as sns
import yaml
from scipy.special import comb
from scipy.stats import trim_mean
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
|
|
|
# Human-readable names for every masker evaluation metric, plus the subset of
# "key" metrics that the bootstrap analysis and summary plot focus on.
dict_metrics = {
    "names": {
        "tpr": "TPR, Recall, Sensitivity",
        "tnr": "TNR, Specificity, Selectivity",
        "fpr": "FPR",
        "fpt": "False positives relative to image size",
        "fnr": "FNR, Miss rate",
        "fnt": "False negatives relative to image size",
        "mpr": "May positive rate (MPR)",
        "mnr": "May negative rate (MNR)",
        "accuracy": "Accuracy (ignoring may)",
        "error": "Error",
        "f05": "F05 score",
        "precision": "Precision",
        "edge_coherence": "Edge coherence",
        "accuracy_must_may": "Accuracy (ignoring cannot)",
    },
    "key_metrics": ["error", "f05", "edge_coherence"],
}

# Ablation techniques analyzed, mapped to their y-axis labels in the plot.
# An OrderedDict fixes the order in which techniques appear in every subplot.
dict_techniques = OrderedDict(
    [
        ("pseudo", "Pseudo labels"),
        ("depth", "Depth (D)"),
        ("seg", "Seg. (S)"),
        ("spade", "SPADE"),
        ("dada_seg", "DADA (S)"),
        ("dada_masker", "DADA (M)"),
    ]
)

# Feature columns of the results CSV that describe a model variant. Two models
# form a "pair" for a given technique when they agree on all of these columns
# except the technique itself (see find_model_pairs).
model_feats = [
    "masker",
    "seg",
    "depth",
    "dada_seg",
    "dada_masker",
    "spade",
    "pseudo",
    "ground",
    "instagan",
]

# One color per key metric, sampled from the "crest" palette.
crest = sns.color_palette("crest", as_cmap=False, n_colors=7)
palette_metrics = [crest[0], crest[3], crest[6]]
# NOTE(review): palplot draws a palette-preview figure as a module-level side
# effect — presumably left over from notebook usage; confirm it is intentional.
sns.palplot(palette_metrics)

# Marker shape used for each key metric in the point plots.
dict_markers = {"error": "o", "f05": "s", "edge_coherence": "^"}
|
|
|
|
|
def parsed_args():
    """
    Parse and returns command-line args

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--input_csv",
        default="ablations_metrics_20210311.csv",
        type=str,
        help="CSV containing the results of the ablation study",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="Output directory",
    )
    parser.add_argument(
        "--dpi",
        default=200,
        type=int,
        help="DPI for the output images",
    )
    parser.add_argument(
        "--n_bs",
        # FIX: argparse does not pass defaults through `type`, so the previous
        # default of 1e6 left args.n_bs as a float; the default must be an int.
        default=1_000_000,
        type=int,
        help="Number of bootstrap samples",
    )
    parser.add_argument(
        "--alpha",
        default=0.99,
        type=float,
        help="Confidence level",
    )
    parser.add_argument(
        "--bs_seed",
        default=17,
        type=int,
        help="Bootstrap random seed, for reproducibility",
    )

    return parser.parse_args()
|
|
|
|
|
def trim_mean_wrapper(a):
    """Return the 20 % trimmed mean of *a*.

    Thin single-argument wrapper around :func:`scipy.stats.trim_mean`, so it
    can be handed to APIs expecting an estimator callable (e.g. seaborn's
    ``estimator=``).
    """
    return trim_mean(a, 0.2)
|
|
|
|
|
def find_model_pairs(technique, model_feats):
    """Find pairs of models differing only in the given technique.

    Relies on the module-level DataFrame ``df`` (one row per model/image,
    with a ``model_feats`` identifier column and one boolean column per
    feature in *model_feats*).

    Args:
        technique (str): the feature column to pair on.
        model_feats (list): all feature columns describing a model variant.

    Returns:
        list: ``(model_with, model_without)`` tuples where both models agree
        on every feature except *technique* (at most one match per
        ``model_with``).
    """
    model_pairs = []
    all_models = df.model_feats.unique()
    for with_model in df.loc[df[technique]].model_feats.unique():
        for candidate in all_models:
            if candidate == with_model:
                continue
            # Skip candidates that also use the technique
            if df.loc[df.model_feats == candidate, technique].unique()[0]:
                continue
            # A pair must agree on every feature other than the technique
            is_pair = all(
                df.loc[df.model_feats == candidate, feat].unique()[0]
                == df.loc[df.model_feats == with_model, feat].unique()[0]
                for feat in model_feats
                if feat != technique
            )
            if is_pair:
                model_pairs.append((with_model, candidate))
                break
    return model_pairs
|
|
|
|
|
if __name__ == "__main__":

    # -----------------------
    # -----  Arguments  -----
    # -----------------------
    args = parsed_args()
    print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))

    # Determine output dir: fall back to the SLURM temporary directory when
    # none is given (requires the SLURM_TMPDIR environment variable).
    if args.output_dir is None:
        output_dir = Path(os.environ["SLURM_TMPDIR"])
    else:
        output_dir = Path(args.output_dir)
        if not output_dir.exists():
            output_dir.mkdir(parents=True, exist_ok=False)

    # Store the run configuration next to the results, for reproducibility
    output_yml = output_dir / "bootstrap_summary.yml"
    with open(output_yml, "w") as f:
        yaml.dump(vars(args), f)

    # Read CSV of per-model, per-image metrics
    df = pd.read_csv(args.input_csv, index_col="model_img_idx")

    # Build a long-format frame of per-image metric differences for every
    # (model with technique, model without technique) pair.
    dfbs = pd.DataFrame(columns=["diff", "technique", "metric"])
    for technique in model_feats:

        # Models that differ only in the inclusion of this technique
        model_pairs = find_model_pairs(technique, model_feats)

        for m_with, m_without in model_pairs:
            df_with = df.loc[df.model_feats == m_with]
            df_without = df.loc[df.model_feats == m_without]
            for metric in dict_metrics["key_metrics"]:
                # Sort both sides by img_idx so differences are paired per image
                diff = (
                    df_with.sort_values(by="img_idx")[metric].values
                    - df_without.sort_values(by="img_idx")[metric].values
                )
                dfm = pd.DataFrame.from_dict(
                    {"metric": metric, "technique": technique, "diff": diff}
                )
                # FIX: DataFrame.append is deprecated (removed in pandas 2.0);
                # pd.concat is the supported equivalent.
                dfbs = pd.concat([dfbs, dfm], ignore_index=True)

    # ------------------
    # -----  Plot  -----
    # ------------------
    sns.reset_orig()
    sns.set(style="whitegrid")
    plt.rcParams.update({"font.family": "serif"})
    plt.rcParams.update(
        {
            "font.serif": [
                "Computer Modern Roman",
                "Times New Roman",
                "Utopia",
                "New Century Schoolbook",
                "Century Schoolbook L",
                "ITC Bookman",
                "Bookman",
                "Times",
                "Palatino",
                "Charter",
                # FIX: the original had `"serif" "Bitstream Vera Serif"` (missing
                # comma), which implicit concatenation turned into the bogus
                # font name "serifBitstream Vera Serif".
                "serif",
                "Bitstream Vera Serif",
                "DejaVu Serif",
            ]
        }
    )

    fig, axes = plt.subplots(
        nrows=1, ncols=3, sharey=True, dpi=args.dpi, figsize=(9, 3)
    )

    # Trimmed means and confidence intervals, filled in while plotting and
    # dumped to YAML at the end: {metric: {technique: {mean, ci_left, ci_right}}}
    dict_ci = {m: {} for m in dict_metrics["key_metrics"]}

    for idx, metric in enumerate(dict_metrics["key_metrics"]):

        # One subplot per key metric: 20 % trimmed mean of the per-image
        # differences per technique, with percentile-bootstrap CIs.
        ax = sns.pointplot(
            ax=axes[idx],
            data=dfbs.loc[dfbs.metric.isin(["error", "f05", "edge_coherence"])],
            order=dict_techniques.keys(),
            x="diff",
            y="technique",
            hue="metric",
            hue_order=[metric],
            markers=dict_markers[metric],
            palette=[palette_metrics[idx]],
            errwidth=1.5,
            scale=0.6,
            join=False,
            estimator=trim_mean_wrapper,
            ci=int(args.alpha * 100),
            n_boot=args.n_bs,
            seed=args.bs_seed,
        )

        # Recover the CIs seaborn drew: one line per technique, whose x data
        # are the interval endpoints.
        for line, technique in zip(ax.lines, dict_techniques.keys()):
            dict_ci[metric].update(
                {
                    technique: {
                        "20_trimmed_mean": float(
                            trim_mean_wrapper(
                                dfbs.loc[
                                    (dfbs.technique == technique)
                                    & (dfbs.metric == metric),
                                    "diff",
                                ].values
                            )
                        ),
                        "ci_left": float(line.get_xdata()[0]),
                        "ci_right": float(line.get_xdata()[1]),
                    }
                }
            )

        sns.despine(left=True, bottom=True)

        # Axis cosmetics
        ax.set_ylabel(None)
        ax.set_yticklabels(list(dict_techniques.values()), fontsize="medium")
        ax.set_xlabel(None)
        xticks = ax.get_xticks()
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticks, fontsize="small")

        # Shrink the y limits slightly (2 % margin in axes coordinates) so
        # markers are not clipped at the top/bottom edges.
        display2data = ax.transData.inverted()
        _, y_bottom = display2data.transform(ax.transAxes.transform((0.0, 0.02)))
        _, y_top = display2data.transform(ax.transAxes.transform((0.0, 0.98)))
        ax.set_ylim(bottom=y_bottom, top=y_top)

        # Dotted vertical line at zero difference
        y = np.arange(ax.get_ylim()[1], ax.get_ylim()[0], 0.1)
        x = 0.0 * np.ones(y.shape[0])
        ax.plot(x, y, linestyle=":", linewidth=1.5, color="black")

        # Shade the "improvement" half-plane: negative differences for error,
        # positive differences for the other metrics.
        xlim = ax.get_xlim()
        if metric == "error":
            x0 = xlim[0]
            width = np.abs(x0)
        else:
            x0 = 0.0
            width = np.abs(xlim[1])
        trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
        rect = mpatches.Rectangle(
            xy=(x0, 0.0),
            width=width,
            height=1,
            transform=trans,
            linewidth=0.0,
            edgecolor="none",
            facecolor="gray",
            alpha=0.05,
        )
        ax.add_patch(rect)

        # Legend: show the metric's long name above the subplot
        leg_handles, leg_labels = ax.get_legend_handles_labels()
        leg_labels = [dict_metrics["names"][m] for m in leg_labels]
        ax.legend(
            handles=leg_handles,
            labels=leg_labels,
            loc="center",
            title="",
            bbox_to_anchor=(-0.2, 1.05, 1.0, 0.0),
            framealpha=1.0,
            frameon=False,
            handletextpad=-0.2,
        )

    fig.suptitle(
        "20 % trimmed mean difference and bootstrapped confidence intervals",
        y=0.0,
        fontsize="medium",
    )

    # Save figure
    output_fig = output_dir / "bootstrap_summary.png"
    fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")

    # Store numerical results (trimmed means and CIs)
    output_results = output_dir / "bootstrap_summary_results.yml"
    with open(output_results, "w") as f:
        yaml.dump(dict_ci, f)
|
|