""" | |
This script computes the median difference and confidence intervals of all techniques from the ablation study for | |
improving the masker evaluation metrics. The differences in the metrics are computed | |
for all images of paired models, that is those which only differ in the inclusion or | |
not of the given technique. Then, statistical inference is performed through the | |
percentile bootstrap to obtain robust estimates of the differences in the metrics and | |
confidence intervals. The script plots the summary for all techniques. | |
""" | |
print("Imports...", end="") | |
import os
from argparse import ArgumentParser
from collections import OrderedDict
from pathlib import Path

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import matplotlib.transforms as transforms
import numpy as np
import pandas as pd
import seaborn as sns
import yaml
from scipy.special import comb
from scipy.stats import trim_mean
from tqdm import tqdm
# -----------------------
# ----- Constants -----
# -----------------------

# Human-readable names for each metric column, plus the subset of metrics
# that is summarized in the final plot ("key_metrics").
dict_metrics = {
    "names": {
        "tpr": "TPR, Recall, Sensitivity",
        "tnr": "TNR, Specificity, Selectivity",
        "fpr": "FPR",
        "fpt": "False positives relative to image size",
        "fnr": "FNR, Miss rate",
        "fnt": "False negatives relative to image size",
        "mpr": "May positive rate (MPR)",
        "mnr": "May negative rate (MNR)",
        "accuracy": "Accuracy (ignoring may)",
        "error": "Error",
        "f05": "F05 score",
        "precision": "Precision",
        "edge_coherence": "Edge coherence",
        "accuracy_must_may": "Accuracy (ignoring cannot)",
    },
    "key_metrics": ["error", "f05", "edge_coherence"],
}

# Ablation techniques: CSV boolean column name -> display label.
# OrderedDict preserves the display order on the plot's Y axis.
dict_techniques = OrderedDict(
    [
        ("pseudo", "Pseudo labels"),
        ("depth", "Depth (D)"),
        ("seg", "Seg. (S)"),
        ("spade", "SPADE"),
        ("dada_seg", "DADA (S)"),
        ("dada_masker", "DADA (M)"),
    ]
)

# Model features: the boolean columns that, together, identify a model
# configuration in the ablation CSV.
model_feats = [
    "masker",
    "seg",
    "depth",
    "dada_seg",
    "dada_masker",
    "spade",
    "pseudo",
    "ground",
    "instagan",
]

# Colors: one color per key metric, sampled from the "crest" palette.
crest = sns.color_palette("crest", as_cmap=False, n_colors=7)
palette_metrics = [crest[0], crest[3], crest[6]]
sns.palplot(palette_metrics)

# Markers: one marker shape per key metric.
dict_markers = {"error": "o", "f05": "s", "edge_coherence": "^"}
def parsed_args():
    """
    Parse and returns command-line args

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--input_csv",
        default="ablations_metrics_20210311.csv",
        type=str,
        help="CSV containing the results of the ablation study",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="Output directory",
    )
    parser.add_argument(
        "--dpi",
        default=200,
        type=int,
        help="DPI for the output images",
    )
    parser.add_argument(
        "--n_bs",
        # argparse applies `type` only to strings from the command line, so the
        # default must already be an int (1e6 would leak a float through).
        default=1_000_000,
        type=int,
        help="Number of bootstrap samples",
    )
    parser.add_argument(
        "--alpha",
        default=0.99,
        type=float,
        help="Confidence level",
    )
    parser.add_argument(
        "--bs_seed",
        default=17,
        type=int,
        help="Bootstrap random seed, for reproducibility",
    )
    return parser.parse_args()
def trim_mean_wrapper(a):
    """Return the 20 % trimmed mean of *a*.

    Thin wrapper fixing ``proportiontocut`` so it can be passed where a
    single-argument estimator is expected (e.g. ``seaborn.pointplot``).
    """
    return trim_mean(a, proportiontocut=0.2)
def find_model_pairs(technique, model_feats, data=None):
    """
    Find pairs of models that differ only in the inclusion of `technique`.

    Args:
        technique (str): boolean column marking whether a model uses the technique.
        model_feats (list): all boolean feature columns identifying a model.
        data (pd.DataFrame, optional): results table with a `model_feats` id
            column. Defaults to the module-level `df` for backward compatibility.

    Returns:
        list: tuples (model_with, model_without) of `model_feats` identifiers.
    """
    # Fall back to the global DataFrame loaded in __main__ (original behavior).
    df_ = df if data is None else data
    model_pairs = []
    # For every model that uses the technique, look for exactly one model that
    # does not use it but matches on every other feature.
    for mi in df_.loc[df_[technique]].model_feats.unique():
        for mj in df_.model_feats.unique():
            if mj == mi:
                continue
            # Candidate must NOT use the technique.
            if df_.loc[df_.model_feats == mj, technique].unique()[0]:
                continue
            is_pair = True
            for f in model_feats:
                if f == technique:
                    continue
                if (
                    df_.loc[df_.model_feats == mj, f].unique()[0]
                    != df_.loc[df_.model_feats == mi, f].unique()[0]
                ):
                    is_pair = False
                    break
            if is_pair:
                model_pairs.append((mi, mj))
                # At most one partner per model.
                break
    return model_pairs
if __name__ == "__main__": | |
# ----------------------------- | |
# ----- Parse arguments ----- | |
# ----------------------------- | |
args = parsed_args() | |
print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()])) | |
# Determine output dir | |
if args.output_dir is None: | |
output_dir = Path(os.environ["SLURM_TMPDIR"]) | |
else: | |
output_dir = Path(args.output_dir) | |
if not output_dir.exists(): | |
output_dir.mkdir(parents=True, exist_ok=False) | |
# Store args | |
output_yml = output_dir / "bootstrap_summary.yml" | |
with open(output_yml, "w") as f: | |
yaml.dump(vars(args), f) | |
# Read CSV | |
df = pd.read_csv(args.input_csv, index_col="model_img_idx") | |
# Build data set | |
dfbs = pd.DataFrame(columns=["diff", "technique", "metric"]) | |
for technique in model_feats: | |
# Get pairs | |
model_pairs = find_model_pairs(technique, model_feats) | |
# Compute differences | |
for m_with, m_without in model_pairs: | |
df_with = df.loc[df.model_feats == m_with] | |
df_without = df.loc[df.model_feats == m_without] | |
for metric in dict_metrics["key_metrics"]: | |
diff = ( | |
df_with.sort_values(by="img_idx")[metric].values | |
- df_without.sort_values(by="img_idx")[metric].values | |
) | |
dfm = pd.DataFrame.from_dict( | |
{"metric": metric, "technique": technique, "diff": diff} | |
) | |
dfbs = dfbs.append(dfm, ignore_index=True) | |
### Plot | |
# Set up plot | |
sns.reset_orig() | |
sns.set(style="whitegrid") | |
plt.rcParams.update({"font.family": "serif"}) | |
plt.rcParams.update( | |
{ | |
"font.serif": [ | |
"Computer Modern Roman", | |
"Times New Roman", | |
"Utopia", | |
"New Century Schoolbook", | |
"Century Schoolbook L", | |
"ITC Bookman", | |
"Bookman", | |
"Times", | |
"Palatino", | |
"Charter", | |
"serif" "Bitstream Vera Serif", | |
"DejaVu Serif", | |
] | |
} | |
) | |
fig, axes = plt.subplots( | |
nrows=1, ncols=3, sharey=True, dpi=args.dpi, figsize=(9, 3) | |
) | |
metrics = ["error", "f05", "edge_coherence"] | |
dict_ci = {m: {} for m in metrics} | |
for idx, metric in enumerate(dict_metrics["key_metrics"]): | |
ax = sns.pointplot( | |
ax=axes[idx], | |
data=dfbs.loc[dfbs.metric.isin(["error", "f05", "edge_coherence"])], | |
order=dict_techniques.keys(), | |
x="diff", | |
y="technique", | |
hue="metric", | |
hue_order=[metric], | |
markers=dict_markers[metric], | |
palette=[palette_metrics[idx]], | |
errwidth=1.5, | |
scale=0.6, | |
join=False, | |
estimator=trim_mean_wrapper, | |
ci=int(args.alpha * 100), | |
n_boot=args.n_bs, | |
seed=args.bs_seed, | |
) | |
# Retrieve confidence intervals and update results dictionary | |
for line, technique in zip(ax.lines, dict_techniques.keys()): | |
dict_ci[metric].update( | |
{ | |
technique: { | |
"20_trimmed_mean": float( | |
trim_mean_wrapper( | |
dfbs.loc[ | |
(dfbs.technique == technique) | |
& (dfbs.metric == metrics[idx]), | |
"diff", | |
].values | |
) | |
), | |
"ci_left": float(line.get_xdata()[0]), | |
"ci_right": float(line.get_xdata()[1]), | |
} | |
} | |
) | |
leg_handles, leg_labels = ax.get_legend_handles_labels() | |
# Change spines | |
sns.despine(left=True, bottom=True) | |
# Set Y-label | |
ax.set_ylabel(None) | |
# Y-tick labels | |
ax.set_yticklabels(list(dict_techniques.values()), fontsize="medium") | |
# Set X-label | |
ax.set_xlabel(None) | |
# X-ticks | |
xticks = ax.get_xticks() | |
xticklabels = xticks | |
ax.set_xticks(xticks) | |
ax.set_xticklabels(xticklabels, fontsize="small") | |
# Y-lim | |
display2data = ax.transData.inverted() | |
ax2display = ax.transAxes | |
_, y_bottom = display2data.transform(ax.transAxes.transform((0.0, 0.02))) | |
_, y_top = display2data.transform(ax.transAxes.transform((0.0, 0.98))) | |
ax.set_ylim(bottom=y_bottom, top=y_top) | |
# Draw line at H0 | |
y = np.arange(ax.get_ylim()[1], ax.get_ylim()[0], 0.1) | |
x = 0.0 * np.ones(y.shape[0]) | |
ax.plot(x, y, linestyle=":", linewidth=1.5, color="black") | |
# Draw gray area | |
xlim = ax.get_xlim() | |
ylim = ax.get_ylim() | |
if metric == "error": | |
x0 = xlim[0] | |
width = np.abs(x0) | |
else: | |
x0 = 0.0 | |
width = np.abs(xlim[1]) | |
trans = transforms.blended_transform_factory(ax.transData, ax.transAxes) | |
rect = mpatches.Rectangle( | |
xy=(x0, 0.0), | |
width=width, | |
height=1, | |
transform=trans, | |
linewidth=0.0, | |
edgecolor="none", | |
facecolor="gray", | |
alpha=0.05, | |
) | |
ax.add_patch(rect) | |
# Legend | |
leg_handles, leg_labels = ax.get_legend_handles_labels() | |
leg_labels = [dict_metrics["names"][metric] for metric in leg_labels] | |
leg = ax.legend( | |
handles=leg_handles, | |
labels=leg_labels, | |
loc="center", | |
title="", | |
bbox_to_anchor=(-0.2, 1.05, 1.0, 0.0), | |
framealpha=1.0, | |
frameon=False, | |
handletextpad=-0.2, | |
) | |
# Set X-label (title) │ | |
fig.suptitle( | |
"20 % trimmed mean difference and bootstrapped confidence intervals", | |
y=0.0, | |
fontsize="medium", | |
) | |
# Save figure | |
output_fig = output_dir / "bootstrap_summary.png" | |
fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight") | |
# Store results | |
output_results = output_dir / "bootstrap_summary_results.yml" | |
with open(output_results, "w") as f: | |
yaml.dump(dict_ci, f) | |