# climateGAN / figures/metrics.py
# Author: vict0rsch - initial commit from cc-ai/climateGAN (448ebbd), 20.4 kB
"""
This script plots examples of the images that get the best and worst metrics
for the selected model, as well as a scatterplot of those metrics.
"""
print("Imports...", end="")
import os
import sys
from argparse import ArgumentParser
from pathlib import Path
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import yaml
from imageio import imread
from skimage.color import rgba2rgb
from sklearn.metrics.pairwise import euclidean_distances
sys.path.append("../")
from climategan.data import encode_mask_label
from climategan.eval_metrics import edges_coherence_std_min
from eval_masker import crop_and_resize
# -----------------------
# ----- Constants -----
# -----------------------

# Metrics plotted by this script: the __main__ block saves one figure per entry
metrics = ["error", "f05", "edge_coherence"]

# Display names of the metric keys (used as axis labels by the scatterplots),
# plus the subset of key metrics
dict_metrics = {
    "names": {
        "tpr": "TPR, Recall, Sensitivity",
        "tnr": "TNR, Specificity, Selectivity",
        "fpr": "FPR",
        "fpt": "False positives relative to image size",
        "fnr": "FNR, Miss rate",
        "fnt": "False negatives relative to image size",
        "mpr": "May positive rate (MPR)",
        "mnr": "May negative rate (MNR)",
        "accuracy": "Accuracy (ignoring may)",
        "error": "Error",
        "f05": "F05 score",
        "precision": "Precision",
        "edge_coherence": "Edge coherence",
        "accuracy_must_may": "Accuracy (ignoring cannot)",
    },
    "key_metrics": list(metrics),
}

# Colors
# Label classes and prediction overlay, picked from the colorblind palette
colorblind_palette = sns.color_palette("colorblind")
color_cannot, color_must, color_may, color_pred = (
    colorblind_palette[i] for i in (1, 2, 7, 4)
)
# Confusion categories, picked from a 5-colour diverging palette
icefire = sns.color_palette("icefire", as_cmap=False, n_colors=5)
color_tp, color_tn, color_fn, color_fp = icefire[0], icefire[1], icefire[3], icefire[4]
def parsed_args(argv=None):
    """
    Parse and return command-line args

    Args:
        argv (list[str] | None, optional): arguments to parse instead of the
            real command line (``sys.argv[1:]``). Defaults to None.

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--input_csv",
        default="ablations_metrics_20210311.csv",
        type=str,
        help="CSV containing the results of the ablation study",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="Output directory",
    )
    parser.add_argument(
        "--models_log_path",
        default=None,
        type=str,
        help="Path containing the log files of the models",
    )
    parser.add_argument(
        "--masker_test_set_dir",
        default=None,
        type=str,
        help="Directory containing the test images",
    )
    parser.add_argument(
        "--best_model",
        default="dada, msd_spade, pseudo",
        type=str,
        help="The string identifier of the best model",
    )
    parser.add_argument(
        "--dpi",
        default=200,
        type=int,
        help="DPI for the output images",
    )
    parser.add_argument(
        "--alpha",
        default=0.5,
        type=float,
        help="Transparency of labels shade",
    )
    parser.add_argument(
        "--percentile",
        default=0.05,
        type=float,
        # Previously a copy of the --alpha help text
        help="Fraction of the best- and worst-ranked images from which each "
        "example image is randomly drawn",
    )
    parser.add_argument(
        "--seed",
        default=None,
        type=int,
        help="Bootstrap random seed, for reproducibility",
    )
    parser.add_argument(
        "--no_images",
        action="store_true",
        default=False,
        help="Do not generate images",
    )
    # argv=None makes argparse fall back to sys.argv[1:]
    return parser.parse_args(argv)
def map_color(arr, input_color, output_color, rtol=1e-09):
    """
    Return a copy of ``arr`` in which every pixel whose channels all match
    ``input_color`` (per ``np.isclose`` with relative tolerance ``rtol``) is
    replaced by ``output_color``. ``arr`` itself is left untouched.

    Args:
        arr (np.ndarray): image of shape (H, W, C)
        input_color (sequence): C-channel colour to look for
        output_color (sequence): C-channel colour painted on matching pixels
        rtol (float): relative tolerance forwarded to ``np.isclose``

    Returns:
        np.ndarray: the recoloured copy
    """
    # Broadcasting the colour over the last axis is equivalent to tiling it
    # to (H, W, C) before comparing
    target = np.asarray(input_color)
    pixel_matches = np.isclose(arr, target, rtol=rtol).all(axis=-1)
    recolored = arr.copy()
    recolored[pixel_matches] = output_color
    return recolored
def plot_labels(ax, img, label, img_id, do_legend, alpha=0.5):
    """
    Draw ``img`` with its ground-truth label shaded on top.

    Label pixels are recoloured by class before being overlaid:
    pure red (255, 0, 0) -> cannot-be-flooded, pure blue (0, 0, 255) ->
    must-be-flooded, black (0, 0, 0) -> may-be-flooded.

    Args:
        ax (matplotlib.axes.Axes): axes to draw on
        img (np.ndarray): background image, presumably float RGB in [0, 1]
            as prepared by ``plot_images_metric``
        label (np.ndarray): RGB label map, presumably the same height/width
            as ``img``
        img_id (str): identifier written in the top-left corner
        do_legend (bool): whether to draw the class legend above the axes
        alpha (float, optional): opacity of the label shade. Defaults to 0.5.
    """
    label_colmap = label.astype(float)
    label_colmap = map_color(label_colmap, (255, 0, 0), color_cannot)
    label_colmap = map_color(label_colmap, (0, 0, 255), color_must)
    label_colmap = map_color(label_colmap, (0, 0, 0), color_may)
    ax.imshow(img)
    ax.imshow(label_colmap, alpha=alpha)
    ax.axis("off")
    # Annotation: image identifier in the top-left corner
    ax.annotate(
        xy=(0.05, 0.95),
        xycoords="axes fraction",
        xytext=(0.05, 0.95),
        textcoords="axes fraction",
        text=img_id,
        fontsize="x-large",
        verticalalignment="top",
        color="white",
    )
    # Legend
    if do_legend:
        lw = 1.0
        # Each patch now carries its own class name (they were all "must")
        legend_items = [
            (color_must, "Must-be-flooded"),
            (color_may, "May-be-flooded"),
            (color_cannot, "Cannot-be-flooded"),
        ]
        handles = [
            mpatches.Patch(facecolor=color, label=name, linewidth=lw, alpha=0.66)
            for color, name in legend_items
        ]
        ax.legend(
            handles=handles,
            labels=[name for _, name in legend_items],
            bbox_to_anchor=(0.0, 1.0, 1.0, 0.075),
            ncol=3,
            mode="expand",
            fontsize="xx-small",
            frameon=False,
        )
def plot_pred(ax, img, pred, img_id, do_legend):
    """
    Draw ``img`` with the model's prediction shaded on top.

    Args:
        ax (matplotlib.axes.Axes): axes to draw on
        img (np.ndarray): background image, presumably float RGB in [0, 1]
            as prepared by ``plot_images_metric``
        pred (np.ndarray): single-channel (H, W) prediction; pixels equal to
            1 are shaded with ``color_pred``
        img_id (str): identifier written in the top-left corner
        do_legend (bool): whether to draw the legend above the axes
    """
    # (H, W) -> (H, W, 3) so it can be recoloured like an RGB image
    pred = np.tile(np.expand_dims(pred, axis=2), reps=(1, 1, 3))
    pred_colmap = pred.astype(float)
    pred_colmap = map_color(pred_colmap, (1, 1, 1), color_pred)
    # Mask (channel-wise) every value that differs from the prediction
    # colour, so presumably only predicted pixels end up shaded
    pred_colmap_ma = np.ma.masked_not_equal(pred_colmap, color_pred)
    pred_colmap_ma = pred_colmap_ma.mask * img + pred_colmap_ma
    ax.imshow(img)
    ax.imshow(pred_colmap_ma, alpha=0.5)
    ax.axis("off")
    # Annotation: image identifier in the top-left corner
    ax.annotate(
        xy=(0.05, 0.95),
        xycoords="axes fraction",
        xytext=(0.05, 0.95),
        textcoords="axes fraction",
        text=img_id,
        fontsize="x-large",
        verticalalignment="top",
        color="white",
    )
    # Legend
    if do_legend:
        lw = 1.0
        # Label the patch as the prediction (it used to say "must")
        handles = [
            mpatches.Patch(
                facecolor=color_pred, label="Prediction", linewidth=lw, alpha=0.66
            )
        ]
        labels = ["Prediction"]
        ax.legend(
            handles=handles,
            labels=labels,
            bbox_to_anchor=(0.0, 1.0, 1.0, 0.075),
            ncol=3,
            mode="expand",
            fontsize="xx-small",
            frameon=False,
        )
def plot_correct_incorrect(ax, img_filename, img, label, img_id, do_legend):
    """
    Shade the per-pixel TP / TN / FP / FN maps of the current model over
    ``img``, together with the may-be-flooded area of the label.

    Each category map is read from the module-level ``model_path`` (set in
    the ``__main__`` block) as ``eval-metrics/<cat>/<stem>_<cat>.png``;
    pixels equal to 1 in a map are painted with that category's colour.

    Args:
        ax (matplotlib.axes.Axes): axes to draw on
        img_filename (str): file name of the test image (only its stem is used)
        img (np.ndarray): background image, presumably float RGB in [0, 1]
        label (np.ndarray): RGB label map; black (0, 0, 0) pixels are shaded
            as may-be-flooded
        img_id (str): identifier written in the top-left corner
        do_legend (bool): whether to draw the legend above the axes
    """
    stem = Path(img_filename).stem

    def _category_colmap(category, color):
        # Load one single-channel category map, expand it to RGB and paint
        # its 1-valued pixels with ``color``
        raw = imread(
            model_path / f"eval-metrics/{category}" / f"{stem}_{category}.png"
        )
        rgb = np.tile(np.expand_dims(raw, axis=2), reps=(1, 1, 3))
        return map_color(rgb.astype(float), (1, 1, 1), color)

    # Summed in the order fp + fn + tp + tn
    category_colors = (
        ("fp", color_fp),
        ("fn", color_fn),
        ("tp", color_tp),
        ("tn", color_tn),
    )
    maps = sum(_category_colmap(cat, col) for cat, col in category_colors)

    # May-be-flooded area of the label
    may_colmap = map_color(label.astype(float), (0, 0, 0), color_may)
    may_ma = np.ma.masked_not_equal(may_colmap, color_may)
    may_ma = may_ma.mask * img + may_ma

    # Combined category maps, masked where no category is painted
    maps_ma = np.ma.masked_equal(maps, (0, 0, 0))
    maps_ma = maps_ma.mask * img + maps_ma

    ax.imshow(img)
    ax.imshow(may_ma, alpha=0.5)
    ax.imshow(maps_ma, alpha=0.5)
    ax.axis("off")

    # Image identifier in the top-left corner
    ax.annotate(
        xy=(0.05, 0.95),
        xycoords="axes fraction",
        xytext=(0.05, 0.95),
        textcoords="axes fraction",
        text=img_id,
        fontsize="x-large",
        verticalalignment="top",
        color="white",
    )

    if not do_legend:
        return
    legend_entries = (
        (color_tp, "TP"),
        (color_tn, "TN"),
        (color_fp, "FP"),
        (color_fn, "FN"),
        (color_may, "May-be-flooded"),
    )
    ax.legend(
        handles=[
            mpatches.Patch(facecolor=c, label=name, linewidth=1.0, alpha=0.66)
            for c, name in legend_entries
        ],
        labels=[name for _, name in legend_entries],
        bbox_to_anchor=(0.0, 1.0, 1.0, 0.075),
        ncol=5,
        mode="expand",
        fontsize="xx-small",
        frameon=False,
    )
def plot_edge_coherence(ax, img, label, pred, img_id, do_legend, img_filename=None):
    """
    Illustrate the edge-coherence metric for one image on ``ax``.

    Prediction edges and edges of the flood mask encoded from ``label`` (both
    as returned by ``edges_coherence_std_min``) are painted over ``img``.
    The prediction, must-be-flooded label and true-positive areas are
    shaded, and white segments join every 100th prediction-edge pixel
    (sorted by column) to its nearest label-edge pixel.

    Args:
        ax (matplotlib.axes.Axes): axes to draw on
        img (np.ndarray): background image, presumably float RGB in [0, 1]
        label (np.ndarray): RGB label map; pure blue (0, 0, 255) pixels are
            shaded as must-be-flooded
        pred (np.ndarray): single-channel (H, W) prediction; pixels equal to
            1 are shaded with ``color_pred``
        img_id (str): identifier written in the top-left corner
        do_legend (bool): whether to draw the legend above the axes
        img_filename (str, optional): file name of the test image, used to
            locate its TP map under the module-level ``model_path``. Defaults
            to None, which falls back to the module-level ``srs_sel.filename``
            set by the ``__main__`` loop (the previous behaviour).
    """
    if img_filename is None:
        # Backward compatibility: this function used to always read the
        # currently selected row from the __main__ loop
        img_filename = srs_sel.filename

    pred = np.tile(np.expand_dims(pred, axis=2), reps=(1, 1, 3))
    _, pred_ec, label_ec = edges_coherence_std_min(
        np.squeeze(pred[:, :, 0]), np.squeeze(encode_mask_label(label, "flood"))
    )

    ##################
    # Edge distances #
    ##################
    # Location of edges
    pred_ec_coord = np.argwhere(pred_ec > 0)
    label_ec_coord = np.argwhere(label_ec > 0)
    # euclidean_distances rejects empty inputs: only compute distances (and
    # draw the connecting lines below) when both edge sets are non-empty
    has_both_edges = len(pred_ec_coord) > 0 and len(label_ec_coord) > 0
    dist_mat = None
    if has_both_edges:
        # Pairwise distances between pred and label edges, normalized by the
        # image height
        dist_mat = np.divide(
            euclidean_distances(pred_ec_coord, label_ec_coord), pred_ec.shape[0]
        )

    #############
    # Make plot #
    #############
    # Prediction edges, painted with the prediction colour
    pred_ec = np.tile(
        np.expand_dims(np.asarray(pred_ec > 0, dtype=float), axis=2), reps=(1, 1, 3)
    )
    pred_ec_colmap = map_color(pred_ec, (1, 1, 1), color_pred)
    # Label edges, painted with the must-be-flooded colour
    label_ec = np.tile(
        np.expand_dims(np.asarray(label_ec > 0, dtype=float), axis=2), reps=(1, 1, 3)
    )
    label_ec_colmap = map_color(label_ec, (1, 1, 1), color_must)
    # Combined pred and label edges over the image: channels where the edge
    # maps are zero show the image instead
    combined_ec = pred_ec_colmap + label_ec_colmap
    combined_ec_ma = np.ma.masked_equal(combined_ec, (0, 0, 0))
    combined_ec_img = combined_ec_ma.mask * img + combined_ec
    # Pred
    pred_colmap = pred.astype(float)
    pred_colmap = map_color(pred_colmap, (1, 1, 1), color_pred)
    pred_colmap_ma = np.ma.masked_not_equal(pred_colmap, color_pred)
    # Must
    label_colmap = label.astype(float)
    label_colmap = map_color(label_colmap, (0, 0, 255), color_must)
    label_colmap_ma = np.ma.masked_not_equal(label_colmap, color_must)
    # TP
    tp_map = imread(
        model_path / "eval-metrics/tp" / "{}_tp.png".format(Path(img_filename).stem)
    )
    tp_map = np.tile(np.expand_dims(tp_map, axis=2), reps=(1, 1, 3))
    tp_map_colmap = tp_map.astype(float)
    tp_map_colmap = map_color(tp_map_colmap, (1, 1, 1), color_tp)
    tp_map_colmap_ma = np.ma.masked_not_equal(tp_map_colmap, color_tp)
    # Combination: boolean mask algebra selecting, away from the edges, the
    # pred-only and label-only areas; TP everywhere off the edges
    comb_pred = (
        (pred_colmap_ma.mask ^ tp_map_colmap_ma.mask)
        & tp_map_colmap_ma.mask
        & combined_ec_ma.mask
    ) * pred_colmap
    comb_label = (
        (label_colmap_ma.mask ^ pred_colmap_ma.mask)
        & pred_colmap_ma.mask
        & combined_ec_ma.mask
    ) * label_colmap
    comb_tp = combined_ec_ma.mask * tp_map_colmap.copy()
    combined = comb_tp + comb_label + comb_pred
    combined_ma = np.ma.masked_equal(combined, (0, 0, 0))
    combined_ma = combined_ma.mask * combined_ec_img + combined_ma
    ax.imshow(combined_ec_img, alpha=1)
    ax.imshow(combined_ma, alpha=0.5)
    ax.axis("off")
    # Plot lines from a subsample of pred edge pixels (every `offset`-th,
    # sorted by column) to their nearest label edge pixel
    if has_both_edges:
        idx_sort_x = np.argsort(pred_ec_coord[:, 1])
        offset = 100
        for idx in range(offset, pred_ec_coord.shape[0], offset):
            y0, x0 = pred_ec_coord[idx_sort_x[idx], :]
            argmin = np.argmin(dist_mat[idx_sort_x[idx]])
            y1, x1 = label_ec_coord[argmin, :]
            ax.plot([x0, x1], [y0, y1], color="white", linewidth=0.5)
    # Annotation: image identifier in the top-left corner
    ax.annotate(
        xy=(0.05, 0.95),
        xycoords="axes fraction",
        xytext=(0.05, 0.95),
        textcoords="axes fraction",
        text=img_id,
        fontsize="x-large",
        verticalalignment="top",
        color="white",
    )
    # Legend
    if do_legend:
        lw = 1.0
        handles = [
            mpatches.Patch(facecolor=color_tp, label="TP", linewidth=lw, alpha=0.66),
            mpatches.Patch(
                facecolor=color_pred, label="Prediction", linewidth=lw, alpha=0.66
            ),
            mpatches.Patch(
                facecolor=color_must, label="Must-be-flooded", linewidth=lw, alpha=0.66
            ),
        ]
        labels = ["TP", "Prediction", "Must-be-flooded"]
        ax.legend(
            handles=handles,
            labels=labels,
            bbox_to_anchor=(0.0, 1.0, 1.0, 0.075),
            ncol=3,
            mode="expand",
            fontsize="xx-small",
            frameon=False,
        )
def plot_images_metric(axes, metric, img_filename, img_id, do_legend):
    """
    Fill one row of three axes for a single test image.

    ``axes[0]`` shows the label, ``axes[1]`` the prediction, and ``axes[2]``
    either the confusion maps (``error`` / ``f05``) or the edge-coherence
    illustration (``edge_coherence``). Images are read from the module-level
    ``imgs_orig_path``, ``labels_path`` and ``model_path`` set in ``__main__``.

    Args:
        axes (sequence of matplotlib.axes.Axes): the three axes of the row
        metric (str): one of "error", "f05", "edge_coherence"
        img_filename (str): file name of the test image
        img_id (str): identifier written on every panel
        do_legend (bool): whether to draw legends above the axes

    Raises:
        ValueError: if ``metric`` is not supported. Checked before any file
            is read or any axis is drawn.
    """
    supported = ("error", "f05", "edge_coherence")
    if metric not in supported:
        raise ValueError(
            "Unsupported metric {!r}: expected one of {}".format(metric, supported)
        )
    # Read images
    stem = Path(img_filename).stem
    img_path = imgs_orig_path / img_filename
    label_path = labels_path / "{}_labeled.png".format(stem)
    img, label = crop_and_resize(img_path, label_path)
    # RGBA inputs go through rgba2rgb; other inputs are divided by 255
    # (presumably 8-bit RGB)
    img = rgba2rgb(img) if img.shape[-1] == 4 else img / 255.0
    pred = imread(model_path / "eval-metrics/pred" / "{}_pred.png".format(stem))
    # Label
    plot_labels(axes[0], img, label, img_id, do_legend)
    # Prediction
    plot_pred(axes[1], img, pred, img_id, do_legend)
    # Correct / incorrect
    if metric in ("error", "f05"):
        plot_correct_incorrect(axes[2], img_filename, img, label, img_id, do_legend)
    # Edge coherence
    else:
        plot_edge_coherence(axes[2], img, label, pred, img_id, do_legend)
def scatterplot_metrics_pair(ax, df, x_metric, y_metric, dict_images):
    """
    Scatter column ``y_metric`` of ``df`` against column ``x_metric`` on
    ``ax``, label both axes with their display names and annotate the
    example images in ``dict_images``.

    Args:
        ax (matplotlib.axes.Axes): axes to draw on
        df (pd.DataFrame): one row per image, with the metric columns
        x_metric (str): metric on the horizontal axis
        y_metric (str): metric on the vertical axis
        dict_images (dict): image id -> row (indexable by metric name)
    """
    display_names = dict_metrics["names"]
    sns.scatterplot(data=df, x=x_metric, y=y_metric, ax=ax)
    ax.set_xlabel(display_names[x_metric], rotation=0, fontsize="medium")
    ax.set_ylabel(display_names[y_metric], rotation=90, fontsize="medium")
    # Hide the left and bottom spines
    sns.despine(ax=ax, left=True, bottom=True)
    annotate_scatterplot(ax, dict_images, x_metric, y_metric)
def scatterplot_metrics(ax, df, dict_images):
    """
    Scatter F05 (y) against error (x) for every row of ``df``, coloured by
    edge coherence, and annotate the example images in ``dict_images``.

    The x axis is made to start at 0 and the y axis to end at 1.

    Args:
        ax (matplotlib.axes.Axes): axes to draw on
        df (pd.DataFrame): one row per image, with "error", "f05" and
            "edge_coherence" columns
        dict_images (dict): image id -> row (indexable by metric name)
    """
    names = dict_metrics["names"]
    sns.scatterplot(data=df, x="error", y="f05", hue="edge_coherence", ax=ax)
    # Set X-label
    ax.set_xlabel(names["error"], rotation=0, fontsize="medium")
    # Set Y-label
    ax.set_ylabel(names["f05"], rotation=90, fontsize="medium")
    # Change spines
    sns.despine(ax=ax, left=True, bottom=True)
    # Set XY limits BEFORE annotating: annotate_scatterplot derives the text
    # offsets and their direction from the current limits. Computing them on
    # the pre-clamp limits could place labels outside the final view.
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    ax.set_xlim([0.0, xlim[1]])
    ax.set_ylim([ylim[0], 1.0])
    annotate_scatterplot(ax, dict_images, "error", "f05")
def annotate_scatterplot(ax, dict_images, x_metric, y_metric, offset=0.1):
    """
    Label selected points of a scatterplot with an arrow and their id.

    Each label is shifted away from its point by ``offset`` times the current
    axis span, horizontally and vertically. The shift points towards the
    centre of the axes, so labels stay within the current limits.

    Args:
        ax (matplotlib.axes.Axes): axes holding the scatterplot
        dict_images (dict): label text -> point (indexable by metric name)
        x_metric (str): key of the x coordinate in each point
        y_metric (str): key of the y coordinate in each point
        offset (float): shift as a fraction of the axis span
    """
    (x_lo, x_hi), (y_lo, y_hi) = ax.get_xlim(), ax.get_ylim()
    x_span, y_span = x_hi - x_lo, y_hi - y_lo
    x_mid, y_mid = x_hi - x_span / 2.0, y_hi - y_span / 2.0

    def _shift(value, mid, span):
        # Move right/up when below the midpoint, left/down otherwise
        return value + span * offset if value < mid else value - span * offset

    for label_text, point in dict_images.items():
        px, py = point[x_metric], point[y_metric]
        ax.annotate(
            text=label_text,
            xy=(px, py),
            xycoords="data",
            xytext=(_shift(px, x_mid, x_span), _shift(py, y_mid, y_span)),
            textcoords="data",
            arrowprops=dict(facecolor="black", shrink=0.05),
            fontsize="medium",
            color="black",
        )
if __name__ == "__main__":
    # -----------------------------
    # ----- Parse arguments -----
    # -----------------------------
    args = parsed_args()
    print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))

    # Fail early with a clear message instead of a TypeError from Path(None)
    for required in ("masker_test_set_dir", "models_log_path"):
        if getattr(args, required) is None:
            raise ValueError(f"--{required} is required")

    # Determine output dir: --output_dir if given, else $SLURM_TMPDIR
    if args.output_dir is not None:
        output_dir = Path(args.output_dir)
    elif "SLURM_TMPDIR" in os.environ:
        output_dir = Path(os.environ["SLURM_TMPDIR"])
    else:
        raise ValueError("--output_dir must be provided when SLURM_TMPDIR is not set")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Store args
    output_yml = output_dir / "labels.yml"
    with open(output_yml, "w") as f:
        yaml.dump(vars(args), f)

    # Data dirs (module-level: read by plot_images_metric)
    imgs_orig_path = Path(args.masker_test_set_dir) / "imgs"
    labels_path = Path(args.masker_test_set_dir) / "labels"

    # Read CSV
    df = pd.read_csv(args.input_csv, index_col="model_img_idx")

    # Select best model
    df = df.loc[df.model_feats == args.best_model]
    if df.empty:
        raise ValueError(
            f"No rows of {args.input_csv} match --best_model {args.best_model!r}"
        )
    v_key, model_dir = df.model.unique()[0].split("/")
    # Module-level: read by the plotting helpers
    model_path = Path(args.models_log_path) / "ablation-{}".format(v_key) / model_dir

    # Set up plot
    sns.reset_orig()
    sns.set(style="whitegrid")
    plt.rcParams.update({"font.family": "serif"})
    plt.rcParams.update(
        {
            "font.serif": [
                "Computer Modern Roman",
                "Times New Roman",
                "Utopia",
                "New Century Schoolbook",
                "Century Schoolbook L",
                "ITC Bookman",
                "Bookman",
                "Times",
                "Palatino",
                "Charter",
                # These two were fused into "serifBitstream Vera Serif" by a
                # missing comma
                "serif",
                "Bitstream Vera Serif",
                "DejaVu Serif",
            ]
        }
    )

    # `is not None` so that --seed 0 is honoured too
    if args.seed is not None:
        np.random.seed(args.seed)

    img_ids = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    dict_images = {}
    idx = 0
    # Number of top-ranked rows each example is drawn from. Clamped to
    # [1, len(df)] so a tiny --percentile or a small CSV does not make
    # np.random.permutation(0)[0] fail, and iloc stays in range.
    n_candidates = min(len(df), max(1, int(args.percentile * len(df))))
    for metric in metrics:
        fig, axes = plt.subplots(nrows=2, ncols=3, dpi=args.dpi, figsize=(18, 12))
        # Lower error is better; for the other metrics, higher is better
        best_ascending = metric == "error"
        # Row 0: a best example (with legends); row 1: a worst example.
        # Selection order (and thus RNG consumption) matches best-then-worst.
        for row, ascending, do_legend in (
            (0, best_ascending, True),
            (1, not best_ascending, False),
        ):
            idx_rand = np.random.permutation(n_candidates)[0]
            # Deliberately module-level: plot_edge_coherence may read
            # srs_sel.filename
            srs_sel = df.sort_values(by=metric, ascending=ascending).iloc[idx_rand]
            img_id = img_ids[idx]
            dict_images[img_id] = srs_sel
            # Read images
            img_filename = srs_sel.filename
            if not args.no_images:
                plot_images_metric(
                    axes[row, :], metric, img_filename, img_id, do_legend=do_legend
                )
            idx += 1
        # Save figure
        output_fig = output_dir / "{}.png".format(metric)
        fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
        plt.close(fig)

    # Single scatterplot of all metrics. scatterplot_metrics_pair can be
    # used instead to get one panel per pair of metrics.
    fig = plt.figure(dpi=args.dpi)
    scatterplot_metrics(fig.gca(), df, dict_images)
    output_fig = output_dir / "scatterplots.png"
    fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
    plt.close(fig)