|
""" |
|
This script plots examples of the images that get the best and worst metrics.
|
""" |
|
print("Imports...", end="") |
|
import os |
|
import sys |
|
from argparse import ArgumentParser |
|
from pathlib import Path |
|
|
|
import matplotlib.patches as mpatches |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import pandas as pd |
|
import seaborn as sns |
|
import yaml |
|
from imageio import imread |
|
from skimage.color import rgba2rgb |
|
from sklearn.metrics.pairwise import euclidean_distances |
|
|
|
sys.path.append("../") |
|
|
|
from climategan.data import encode_mask_label |
|
from climategan.eval_metrics import edges_coherence_std_min |
|
from eval_masker import crop_and_resize |
|
|
|
|
|
|
|
|
|
|
|
|
|
# Metrics for which the __main__ section draws best/worst example figures
metrics = ["error", "f05", "edge_coherence"]

# Human-readable metric names (used as axis labels by the scatter plots) and
# the list of key metrics
dict_metrics = dict(
    names=dict(
        tpr="TPR, Recall, Sensitivity",
        tnr="TNR, Specificity, Selectivity",
        fpr="FPR",
        fpt="False positives relative to image size",
        fnr="FNR, Miss rate",
        fnt="False negatives relative to image size",
        mpr="May positive rate (MPR)",
        mnr="May negative rate (MNR)",
        accuracy="Accuracy (ignoring may)",
        error="Error",
        f05="F05 score",
        precision="Precision",
        edge_coherence="Edge coherence",
        accuracy_must_may="Accuracy (ignoring cannot)",
    ),
    key_metrics=["error", "f05", "edge_coherence"],
)
|
|
|
|
|
|
|
# Colorblind-friendly colours for the label classes and for the prediction
colorblind_palette = sns.color_palette("colorblind")
color_cannot, color_must, color_may, color_pred = (
    colorblind_palette[i] for i in (1, 2, 7, 4)
)

# Colours of the confusion-matrix categories, sampled from the "icefire" palette
icefire = sns.color_palette("icefire", as_cmap=False, n_colors=5)
color_tp, color_tn, color_fn, color_fp = (icefire[i] for i in (0, 1, 3, 4))
|
|
|
|
|
def parsed_args(argv=None):
    """
    Parse and return the command-line arguments.

    Args:
        argv (list[str], optional): arguments to parse instead of the real
            command line (``sys.argv[1:]``). Defaults to None, i.e. the real
            command line, which is the behaviour of the script.

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--input_csv",
        default="ablations_metrics_20210311.csv",
        type=str,
        help="CSV containing the results of the ablation study",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="Output directory",
    )
    parser.add_argument(
        "--models_log_path",
        default=None,
        type=str,
        help="Path containing the log files of the models",
    )
    parser.add_argument(
        "--masker_test_set_dir",
        default=None,
        type=str,
        help="Directory containing the test images",
    )
    parser.add_argument(
        "--best_model",
        default="dada, msd_spade, pseudo",
        type=str,
        help="The string identifier of the best model",
    )
    parser.add_argument(
        "--dpi",
        default=200,
        type=int,
        help="DPI for the output images",
    )
    # NOTE(review): not referenced elsewhere in this script; the plotting
    # functions use a fixed alpha of 0.5.
    parser.add_argument(
        "--alpha",
        default=0.5,
        type=float,
        help="Transparency of labels shade",
    )
    # Was a copy-paste of the --alpha help text.
    parser.add_argument(
        "--percentile",
        default=0.05,
        type=float,
        help="Fraction of the images, sorted by metric, from which the "
        "example images are sampled",
    )
    parser.add_argument(
        "--seed",
        default=None,
        type=int,
        help="Random seed for sampling the example images, for reproducibility",
    )
    parser.add_argument(
        "--no_images",
        action="store_true",
        default=False,
        help="Do not generate images",
    )

    return parser.parse_args(argv)
|
|
|
|
|
def map_color(arr, input_color, output_color, rtol=1e-09):
    """
    Return a copy of ``arr`` in which every pixel matching ``input_color`` is
    set to ``output_color``.

    A pixel matches when all of its channels are ``np.isclose`` to the
    corresponding channels of ``input_color`` (relative tolerance ``rtol``,
    NumPy's default absolute tolerance). ``arr`` itself is not modified.
    """
    reference = np.tile(input_color, (arr.shape[:2] + (1,)))
    # One boolean per pixel: True where every channel matches the reference
    hits = np.isclose(arr, reference, rtol=rtol).all(axis=2)
    recolored = arr.copy()
    recolored[hits] = output_color
    return recolored
|
|
|
|
|
def plot_labels(ax, img, label, img_id, do_legend):
    """
    Draw ``img`` with its ground-truth label map blended over it.

    The pure red, blue and black pixels of ``label`` (cannot-, must- and
    may-be-flooded) are recoloured with the colorblind-friendly palette.

    Args:
        ax (matplotlib.axes.Axes): axes to draw on
        img (np.ndarray): RGB image drawn underneath
        label (np.ndarray): RGB label map, presumably 0-255 valued given the
            (255, 0, 0) / (0, 0, 255) class colours (TODO confirm)
        img_id (str): identifier written in the top-left corner
        do_legend (bool): whether to draw the legend above the axes
    """
    # Recolour each pure-RGB label class with its palette colour
    label_colmap = label.astype(float)
    for src_color, dst_color in (
        ((255, 0, 0), color_cannot),
        ((0, 0, 255), color_must),
        ((0, 0, 0), color_may),
    ):
        label_colmap = map_color(label_colmap, src_color, dst_color)

    ax.imshow(img)
    ax.imshow(label_colmap, alpha=0.5)
    ax.axis("off")

    # Image identifier in the top-left corner
    ax.annotate(
        xy=(0.05, 0.95),
        xycoords="axes fraction",
        xytext=(0.05, 0.95),
        textcoords="axes fraction",
        text=img_id,
        fontsize="x-large",
        verticalalignment="top",
        color="white",
    )

    if do_legend:
        lw = 1.0
        # Patch labels now match the displayed legend labels (they were all
        # "must" before, masked only by the explicit ``labels`` argument).
        legend_items = [
            (color_must, "Must-be-flooded"),
            (color_may, "May-be-flooded"),
            (color_cannot, "Cannot-be-flooded"),
        ]
        handles = [
            mpatches.Patch(facecolor=color, label=text, linewidth=lw, alpha=0.66)
            for color, text in legend_items
        ]
        labels = [text for _, text in legend_items]
        ax.legend(
            handles=handles,
            labels=labels,
            bbox_to_anchor=(0.0, 1.0, 1.0, 0.075),
            ncol=3,
            mode="expand",
            fontsize="xx-small",
            frameon=False,
        )
|
|
|
|
|
def plot_pred(ax, img, pred, img_id, do_legend):
    """
    Draw ``img`` with the model's prediction blended over it in ``color_pred``.

    Args:
        ax (matplotlib.axes.Axes): axes to draw on
        img (np.ndarray): RGB image drawn underneath
        pred (np.ndarray): 2D prediction map; pixels equal to 1 are shown as
            predicted (presumably a binary 0/1 mask -- TODO confirm)
        img_id (str): identifier written in the top-left corner
        do_legend (bool): whether to draw the legend above the axes
    """
    # Replicate the single-channel prediction into 3 channels so it can be coloured
    pred = np.tile(np.expand_dims(pred, axis=2), reps=(1, 1, 3))

    pred_colmap = map_color(pred.astype(float), (1, 1, 1), color_pred)
    # Mask every channel value that differs from the prediction colour
    pred_colmap_ma = np.ma.masked_not_equal(pred_colmap, color_pred)
    pred_colmap_ma = pred_colmap_ma.mask * img + pred_colmap_ma

    ax.imshow(img)
    ax.imshow(pred_colmap_ma, alpha=0.5)
    ax.axis("off")

    # Image identifier in the top-left corner
    ax.annotate(
        xy=(0.05, 0.95),
        xycoords="axes fraction",
        xytext=(0.05, 0.95),
        textcoords="axes fraction",
        text=img_id,
        fontsize="x-large",
        verticalalignment="top",
        color="white",
    )

    if do_legend:
        lw = 1.0
        # The patch label matches the displayed legend label (was "must")
        handles = [
            mpatches.Patch(
                facecolor=color_pred, label="Prediction", linewidth=lw, alpha=0.66
            )
        ]
        labels = ["Prediction"]
        ax.legend(
            handles=handles,
            labels=labels,
            bbox_to_anchor=(0.0, 1.0, 1.0, 0.075),
            ncol=3,
            mode="expand",
            fontsize="xx-small",
            frameon=False,
        )
|
|
|
|
|
def plot_correct_incorrect(ax, img_filename, img, label, img_id, do_legend):
    """
    Overlay the model's per-pixel TP / TN / FP / FN maps and the
    may-be-flooded area of the label on ``img``.

    The four maps are read from ``model_path / "eval-metrics/<kind>"``, where
    ``model_path`` is the module-level global assigned in the ``__main__``
    section.

    Args:
        ax (matplotlib.axes.Axes): axes to draw on
        img_filename (str): file name of the test image; its stem locates the
            ``<stem>_<kind>.png`` maps
        img (np.ndarray): RGB image drawn underneath
        label (np.ndarray): RGB label map; its black pixels are the
            may-be-flooded area
        img_id (str): identifier written in the top-left corner
        do_legend (bool): whether to draw the legend above the axes
    """
    stem = Path(img_filename).stem

    # Load each confusion map, replicate it to 3 channels and paint its
    # pixels equal to 1 with the category colour. The colour maps are summed
    # in the order fp + fn + tp + tn (same order as before the refactor).
    maps = None
    for kind, color in (
        ("fp", color_fp),
        ("fn", color_fn),
        ("tp", color_tp),
        ("tn", color_tn),
    ):
        kind_map = imread(
            model_path
            / "eval-metrics/{}".format(kind)
            / "{}_{}.png".format(stem, kind)
        )
        kind_map = np.tile(np.expand_dims(kind_map, axis=2), reps=(1, 1, 3))
        kind_colmap = map_color(kind_map.astype(float), (1, 1, 1), color)
        maps = kind_colmap if maps is None else maps + kind_colmap

    # May-be-flooded (black) area of the label
    label_colmap = label.astype(float)
    label_colmap = map_color(label_colmap, (0, 0, 0), color_may)
    label_colmap_ma = np.ma.masked_not_equal(label_colmap, color_may)
    label_colmap_ma = label_colmap_ma.mask * img + label_colmap_ma

    # Mask the channel values where no confusion colour was painted
    maps_ma = np.ma.masked_equal(maps, (0, 0, 0))
    maps_ma = maps_ma.mask * img + maps_ma

    ax.imshow(img)
    ax.imshow(label_colmap_ma, alpha=0.5)
    ax.imshow(maps_ma, alpha=0.5)
    ax.axis("off")

    # Image identifier in the top-left corner
    ax.annotate(
        xy=(0.05, 0.95),
        xycoords="axes fraction",
        xytext=(0.05, 0.95),
        textcoords="axes fraction",
        text=img_id,
        fontsize="x-large",
        verticalalignment="top",
        color="white",
    )

    if do_legend:
        lw = 1.0
        legend_items = [
            (color_tp, "TP"),
            (color_tn, "TN"),
            (color_fp, "FP"),
            (color_fn, "FN"),
            (color_may, "May-be-flooded"),
        ]
        handles = [
            mpatches.Patch(facecolor=color, label=text, linewidth=lw, alpha=0.66)
            for color, text in legend_items
        ]
        labels = [text for _, text in legend_items]
        ax.legend(
            handles=handles,
            labels=labels,
            bbox_to_anchor=(0.0, 1.0, 1.0, 0.075),
            ncol=5,
            mode="expand",
            fontsize="xx-small",
            frameon=False,
        )
|
|
|
|
|
def plot_edge_coherence(ax, img, label, pred, img_id, do_legend, img_filename=None):
    """
    Visualise the edge-coherence metric of a prediction on ``ax``.

    The drawing has three parts:
    - the edges of the prediction (``color_pred``) and of the "flood" label
      (``color_must``) returned by ``edges_coherence_std_min``;
    - the prediction, the must-be-flooded label and the model's true-positive
      map, blended over those edges;
    - a white segment from every 100th predicted edge pixel (sorted left to
      right) to its nearest label edge pixel.

    Args:
        ax (matplotlib.axes.Axes): axes to draw on
        img (np.ndarray): RGB image drawn underneath
        label (np.ndarray): RGB label map; pure blue pixels are must-be-flooded
        pred (np.ndarray): 2D prediction map; pixels equal to 1 are drawn as
            predicted (presumably binary 0/1 -- TODO confirm)
        img_id (str): identifier written in the top-left corner
        do_legend (bool): whether to draw the legend above the axes
        img_filename (str, optional): file name of the test image. It is used
            to locate the model's true-positive map under the module-level
            ``model_path``. When None, it falls back to the ``filename`` of the
            module-level ``srs_sel`` set by the ``__main__`` loop, which was
            the only behaviour of earlier versions.
    """
    if img_filename is None:
        # Backward-compatible fallback on the global set by the __main__ loop
        img_filename = srs_sel.filename

    # Replicate the single-channel prediction into 3 channels so it can be coloured
    pred = np.tile(np.expand_dims(pred, axis=2), reps=(1, 1, 3))

    # Edge maps of the prediction and of the "flood" label (the scalar metric
    # value itself is not needed here)
    _, pred_ec, label_ec = edges_coherence_std_min(
        np.squeeze(pred[:, :, 0]), np.squeeze(encode_mask_label(label, "flood"))
    )

    # (row, col) coordinates of the edge pixels
    pred_ec_coord = np.argwhere(pred_ec > 0)
    label_ec_coord = np.argwhere(label_ec > 0)

    # Pairwise pred -> label edge distances, divided by the number of rows of
    # the edge map. scikit-learn rejects inputs with 0 samples, so when either
    # edge set is empty the distances (and the linking lines drawn from them)
    # are skipped instead of crashing the whole figure.
    dist_mat = None
    if len(pred_ec_coord) > 0 and len(label_ec_coord) > 0:
        dist_mat = np.divide(
            euclidean_distances(pred_ec_coord, label_ec_coord), pred_ec.shape[0]
        )

    # Edge overlays: predicted edges in color_pred, label edges in color_must
    pred_ec = np.tile(
        np.expand_dims(np.asarray(pred_ec > 0, dtype=float), axis=2), reps=(1, 1, 3)
    )
    pred_ec_colmap = map_color(pred_ec, (1, 1, 1), color_pred)

    label_ec = np.tile(
        np.expand_dims(np.asarray(label_ec > 0, dtype=float), axis=2), reps=(1, 1, 3)
    )
    label_ec_colmap = map_color(label_ec, (1, 1, 1), color_must)

    # Image with the edge colours painted over it: channels where
    # combined_ec is 0 take the image value
    combined_ec = pred_ec_colmap + label_ec_colmap
    combined_ec_ma = np.ma.masked_equal(combined_ec, (0, 0, 0))
    combined_ec_img = combined_ec_ma.mask * img + combined_ec

    # Prediction overlay
    pred_colmap = map_color(pred.astype(float), (1, 1, 1), color_pred)
    pred_colmap_ma = np.ma.masked_not_equal(pred_colmap, color_pred)

    # Must-be-flooded (blue) area of the label
    label_colmap = map_color(label.astype(float), (0, 0, 255), color_must)
    label_colmap_ma = np.ma.masked_not_equal(label_colmap, color_must)

    # True positives of the model for this image
    tp_map = imread(
        model_path / "eval-metrics/tp" / "{}_tp.png".format(Path(img_filename).stem)
    )
    tp_map = np.tile(np.expand_dims(tp_map, axis=2), reps=(1, 1, 3))
    tp_map_colmap = map_color(tp_map.astype(float), (1, 1, 1), color_tp)
    tp_map_colmap_ma = np.ma.masked_not_equal(tp_map_colmap, color_tp)

    # Where no edge colour was painted (combined_ec_ma.mask; masks are per
    # channel) keep predicted pixels that are not TP, must-be-flooded pixels
    # that are not predicted, and TP pixels. Note (a ^ b) & b == b & ~a.
    comb_pred = (
        (pred_colmap_ma.mask ^ tp_map_colmap_ma.mask)
        & tp_map_colmap_ma.mask
        & combined_ec_ma.mask
    ) * pred_colmap
    comb_label = (
        (label_colmap_ma.mask ^ pred_colmap_ma.mask)
        & pred_colmap_ma.mask
        & combined_ec_ma.mask
    ) * label_colmap
    comb_tp = combined_ec_ma.mask * tp_map_colmap.copy()
    combined = comb_tp + comb_label + comb_pred
    combined_ma = np.ma.masked_equal(combined, (0, 0, 0))
    combined_ma = combined_ma.mask * combined_ec_img + combined_ma

    ax.imshow(combined_ec_img, alpha=1)
    ax.imshow(combined_ma, alpha=0.5)
    ax.axis("off")

    # Link every `offset`-th predicted edge pixel (sorted by column) to its
    # nearest label edge pixel
    if dist_mat is not None:
        idx_sort_x = np.argsort(pred_ec_coord[:, 1])
        offset = 100
        for idx in range(offset, pred_ec_coord.shape[0], offset):
            y0, x0 = pred_ec_coord[idx_sort_x[idx], :]
            argmin = np.argmin(dist_mat[idx_sort_x[idx]])
            y1, x1 = label_ec_coord[argmin, :]
            ax.plot([x0, x1], [y0, y1], color="white", linewidth=0.5)

    # Image identifier in the top-left corner
    ax.annotate(
        xy=(0.05, 0.95),
        xycoords="axes fraction",
        xytext=(0.05, 0.95),
        textcoords="axes fraction",
        text=img_id,
        fontsize="x-large",
        verticalalignment="top",
        color="white",
    )

    if do_legend:
        lw = 1.0
        legend_items = [
            (color_tp, "TP"),
            (color_pred, "Prediction"),
            (color_must, "Must-be-flooded"),
        ]
        handles = [
            mpatches.Patch(facecolor=color, label=text, linewidth=lw, alpha=0.66)
            for color, text in legend_items
        ]
        labels = [text for _, text in legend_items]
        ax.legend(
            handles=handles,
            labels=labels,
            bbox_to_anchor=(0.0, 1.0, 1.0, 0.075),
            ncol=3,
            mode="expand",
            fontsize="xx-small",
            frameon=False,
        )
|
|
|
|
|
def plot_images_metric(axes, metric, img_filename, img_id, do_legend):
    """
    Plot one row of three panels for a test image: the labels, the
    prediction, and a metric-specific visualisation.

    The image and label are read from the module-level ``imgs_orig_path`` /
    ``labels_path``. The prediction is read from ``model_path``. All three
    are assigned in the ``__main__`` section.

    Args:
        axes: sequence of (at least) three matplotlib axes
        metric (str): "error" or "f05" (confusion-map panel) or
            "edge_coherence" (edge panel)
        img_filename (str): file name of the test image
        img_id (str): identifier written in the top-left corner of each panel
        do_legend (bool): whether to draw the legends

    Raises:
        ValueError: if ``metric`` is not supported. This is checked before
            any file is read or anything is drawn.
    """
    if metric not in ("error", "f05", "edge_coherence"):
        raise ValueError(
            "Unsupported metric {!r}: expected 'error', 'f05' or "
            "'edge_coherence'".format(metric)
        )

    # Read image, label and prediction
    img_path = imgs_orig_path / img_filename
    label_path = labels_path / "{}_labeled.png".format(Path(img_filename).stem)
    img, label = crop_and_resize(img_path, label_path)
    # Bring the image to RGB floats: drop alpha if any, else rescale 0-255 to 0-1
    img = rgba2rgb(img) if img.shape[-1] == 4 else img / 255.0
    pred = imread(
        model_path / "eval-metrics/pred" / "{}_pred.png".format(Path(img_filename).stem)
    )

    # Panel 1: labels
    plot_labels(axes[0], img, label, img_id, do_legend)

    # Panel 2: prediction
    plot_pred(axes[1], img, pred, img_id, do_legend)

    # Panel 3: metric-specific visualisation
    if metric == "edge_coherence":
        plot_edge_coherence(axes[2], img, label, pred, img_id, do_legend)
    else:  # "error" or "f05"
        plot_correct_incorrect(axes[2], img_filename, img, label, img_id, do_legend)
|
|
|
|
|
def scatterplot_metrics_pair(ax, df, x_metric, y_metric, dict_images):
    """
    Scatter-plot two metric columns of ``df`` against each other on ``ax``.

    The axes are labelled with the human-readable names from
    ``dict_metrics["names"]``, left/bottom spines are removed, and every
    example image in ``dict_images`` is annotated with its identifier.
    """
    readable = dict_metrics["names"]

    sns.scatterplot(data=df, x=x_metric, y=y_metric, ax=ax)
    ax.set_xlabel(readable[x_metric], rotation=0, fontsize="medium")
    ax.set_ylabel(readable[y_metric], rotation=90, fontsize="medium")
    sns.despine(ax=ax, left=True, bottom=True)

    annotate_scatterplot(ax, dict_images, x_metric, y_metric)
|
|
|
|
|
def scatterplot_metrics(ax, df, dict_images):
    """
    Scatter-plot error (x) against F05 (y) for every row of ``df``, coloured
    by edge coherence, and annotate the example images in ``dict_images``.

    The x axis is then forced to start at 0 and the y axis to end at 1.
    Presumably this is because error is non-negative and F05 is at most 1.
    """
    x_metric, y_metric = "error", "f05"
    readable = dict_metrics["names"]

    sns.scatterplot(data=df, x=x_metric, y=y_metric, hue="edge_coherence", ax=ax)
    ax.set_xlabel(readable[x_metric], rotation=0, fontsize="medium")
    ax.set_ylabel(readable[y_metric], rotation=90, fontsize="medium")

    # Annotations are placed using the limits before they are pinned below
    annotate_scatterplot(ax, dict_images, x_metric, y_metric)

    sns.despine(ax=ax, left=True, bottom=True)

    _, x_upper = ax.get_xlim()
    y_lower, _ = ax.get_ylim()
    ax.set_xlim([0.0, x_upper])
    ax.set_ylim([y_lower, 1.0])
|
|
|
|
|
def annotate_scatterplot(ax, dict_images, x_metric, y_metric, offset=0.1):
    """
    Annotate each example image's point on a scatter plot with its identifier.

    For every ``identifier -> row`` entry of ``dict_images``, the point
    ``(row[x_metric], row[y_metric])`` gets an arrow pointing at it. The label
    is shifted by ``offset`` times the axis span towards the middle of the
    current axis limits, presumably to keep it inside the plot.
    """
    (x_lo, x_hi), (y_lo, y_hi) = ax.get_xlim(), ax.get_ylim()
    x_span = x_hi - x_lo
    y_span = y_hi - y_lo
    x_mid = x_hi - x_span / 2.0
    y_mid = y_hi - y_span / 2.0

    def _towards_middle(value, mid, span):
        # Below the middle: move up/right; otherwise move down/left
        step = span * offset
        if value < mid:
            return value + step
        return value - step

    for identifier, row in dict_images.items():
        x, y = row[x_metric], row[y_metric]
        ax.annotate(
            xy=(x, y),
            xycoords="data",
            xytext=(
                _towards_middle(x, x_mid, x_span),
                _towards_middle(y, y_mid, y_span),
            ),
            textcoords="data",
            text=identifier,
            arrowprops=dict(facecolor="black", shrink=0.05),
            fontsize="medium",
            color="black",
        )
|
|
|
|
|
if __name__ == "__main__":
    # NOTE: this section intentionally runs at module level. The plotting
    # functions above read `model_path`, `imgs_orig_path`, `labels_path` and
    # `srs_sel` as module globals assigned here.

    # -----------------------------
    # -----  Parse arguments  -----
    # -----------------------------
    args = parsed_args()
    print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))

    # Fail early with a clear message instead of a TypeError from Path(None)
    for required in ("masker_test_set_dir", "models_log_path"):
        if getattr(args, required) is None:
            sys.exit("Error: --{} must be provided".format(required))

    # Output directory: defaults to $SLURM_TMPDIR
    if args.output_dir is None:
        output_dir = Path(os.environ["SLURM_TMPDIR"])
    else:
        output_dir = Path(args.output_dir)
    if not output_dir.exists():
        output_dir.mkdir(parents=True, exist_ok=False)

    # Store the arguments of this run next to the figures
    output_yml = output_dir / "labels.yml"
    with open(output_yml, "w") as f:
        yaml.dump(vars(args), f)

    # Data directories
    imgs_orig_path = Path(args.masker_test_set_dir) / "imgs"
    labels_path = Path(args.masker_test_set_dir) / "labels"

    # Read the CSV and keep only the rows of the selected model
    df = pd.read_csv(args.input_csv, index_col="model_img_idx")
    df = df.loc[df.model_feats == args.best_model]
    if df.empty:
        sys.exit(
            "Error: no rows with model_feats == {!r} in {}".format(
                args.best_model, args.input_csv
            )
        )
    v_key, model_dir = df.model.unique()[0].split("/")
    model_path = Path(args.models_log_path) / "ablation-{}".format(v_key) / model_dir

    # Plot style
    sns.reset_orig()
    sns.set(style="whitegrid")
    plt.rcParams.update({"font.family": "serif"})
    plt.rcParams.update(
        {
            "font.serif": [
                "Computer Modern Roman",
                "Times New Roman",
                "Utopia",
                "New Century Schoolbook",
                "Century Schoolbook L",
                "ITC Bookman",
                "Bookman",
                "Times",
                "Palatino",
                "Charter",
                # These two used to be fused into "serifBitstream Vera Serif"
                # by a missing comma.
                "serif",
                "Bitstream Vera Serif",
                "DejaVu Serif",
            ]
        }
    )

    # `is not None` so that --seed 0 is honoured too
    if args.seed is not None:
        np.random.seed(args.seed)

    img_ids = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    dict_images = {}
    idx = 0
    # Examples are drawn uniformly among the first `percentile` fraction of
    # the sorted rows. Keep at least one candidate so that permutation() never
    # returns an empty array (IndexError) for a small df or percentile.
    n_candidates = max(1, int(args.percentile * len(df)))

    for metric in metrics:
        fig, axes = plt.subplots(nrows=2, ncols=3, dpi=args.dpi, figsize=(18, 12))

        # Row 0 sorts "error" ascending and the other metrics descending
        # (presumably best examples first); row 1 uses the opposite order.
        first_ascending = metric == "error"
        for row, ascending in enumerate((first_ascending, not first_ascending)):
            idx_rand = np.random.permutation(n_candidates)[0]
            srs_sel = df.sort_values(by=metric, ascending=ascending).iloc[idx_rand]
            img_id = img_ids[idx]
            dict_images[img_id] = srs_sel

            img_filename = srs_sel.filename
            if not args.no_images:
                plot_images_metric(
                    axes[row, :], metric, img_filename, img_id, do_legend=row == 0
                )
            idx += 1

        output_fig = output_dir / "{}.png".format(metric)
        fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
        plt.close(fig)

    # Scatter plot of all images with the selected examples annotated
    fig = plt.figure(dpi=args.dpi)
    scatterplot_metrics(fig.gca(), df, dict_images)
    output_fig = output_dir / "scatterplots.png"
    fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
    plt.close(fig)
|
|