"""
This script plots examples of the images that get best and worse metrics
"""
print("Imports...", end="")
import os
import sys
from argparse import ArgumentParser
from pathlib import Path

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import yaml
from imageio import imread
from skimage.color import rgba2rgb
from sklearn.metrics.pairwise import euclidean_distances

sys.path.append("../")

from climategan.data import encode_mask_label
from climategan.eval_metrics import edges_coherence_std_min
from eval_masker import crop_and_resize

# -----------------------
# -----  Constants  -----
# -----------------------

# Metrics
metrics = ["error", "f05", "edge_coherence"]

dict_metrics = {
    "names": {
        "tpr": "TPR, Recall, Sensitivity",
        "tnr": "TNR, Specificity, Selectivity",
        "fpr": "FPR",
        "fpt": "False positives relative to image size",
        "fnr": "FNR, Miss rate",
        "fnt": "False negatives relative to image size",
        "mpr": "May positive rate (MPR)",
        "mnr": "May negative rate (MNR)",
        "accuracy": "Accuracy (ignoring may)",
        "error": "Error",
        "f05": "F05 score",
        "precision": "Precision",
        "edge_coherence": "Edge coherence",
        "accuracy_must_may": "Accuracy (ignoring cannot)",
    },
    "key_metrics": ["error", "f05", "edge_coherence"],
}

# Colors
colorblind_palette = sns.color_palette("colorblind")
color_cannot = colorblind_palette[1]
color_must = colorblind_palette[2]
color_may = colorblind_palette[7]
color_pred = colorblind_palette[4]

icefire = sns.color_palette("icefire", as_cmap=False, n_colors=5)
color_tp = icefire[0]
color_tn = icefire[1]
color_fp = icefire[4]
color_fn = icefire[3]


def parsed_args():
    """
    Parse and returns command-line args

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--input_csv",
        default="ablations_metrics_20210311.csv",
        type=str,
        help="CSV containing the results of the ablation study",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="Output directory",
    )
    parser.add_argument(
        "--models_log_path",
        default=None,
        type=str,
        help="Path containing the log files of the models",
    )
    parser.add_argument(
        "--masker_test_set_dir",
        default=None,
        type=str,
        help="Directory containing the test images",
    )
    parser.add_argument(
        "--best_model",
        default="dada, msd_spade, pseudo",
        type=str,
        help="The string identifier of the best model",
    )
    parser.add_argument(
        "--dpi",
        default=200,
        type=int,
        help="DPI for the output images",
    )
    parser.add_argument(
        "--alpha",
        default=0.5,
        type=float,
        help="Transparency of labels shade",
    )
    parser.add_argument(
        "--percentile",
        default=0.05,
        type=float,
        help="Fraction of the images, ranked by each metric, among which the "
        "best and worst examples are randomly drawn",
    )
    parser.add_argument(
        "--seed",
        default=None,
        type=int,
        help="Bootstrap random seed, for reproducibility",
    )
    parser.add_argument(
        "--no_images",
        action="store_true",
        default=False,
        help="Do not generate images",
    )

    return parser.parse_args()


def map_color(arr, input_color, output_color, rtol=1e-09):
    """
    Maps one color to another

    Every pixel of `arr` whose channels are all close (np.isclose with the
    given `rtol`) to `input_color` is set to `output_color`. The input array
    is not modified; a modified copy is returned.

    Args:
        arr (np.ndarray): H x W x C image
        input_color (tuple): color to look for
        output_color (tuple): replacement color
        rtol (float): relative tolerance forwarded to np.isclose

    Returns:
        np.ndarray: recolored copy of `arr`
    """
    input_color_arr = np.tile(input_color, (arr.shape[:2] + (1,)))
    output = arr.copy()
    output[np.all(np.isclose(arr, input_color_arr, rtol=rtol), axis=2)] = output_color
    return output


def _annotate_img_id(ax, img_id):
    """Write the image identifier in the top-left corner of the axes."""
    ax.annotate(
        xy=(0.05, 0.95),
        xycoords="axes fraction",
        xytext=(0.05, 0.95),
        textcoords="axes fraction",
        text=img_id,
        fontsize="x-large",
        verticalalignment="top",
        color="white",
    )


def _add_legend(ax, entries, ncol):
    """Draw a frameless legend above the axes from (color, label) pairs."""
    handles = [
        mpatches.Patch(facecolor=color, label=label, linewidth=1.0, alpha=0.66)
        for color, label in entries
    ]
    ax.legend(
        handles=handles,
        labels=[label for _, label in entries],
        bbox_to_anchor=(0.0, 1.0, 1.0, 0.075),
        ncol=ncol,
        mode="expand",
        fontsize="xx-small",
        frameon=False,
    )


def _load_outcome_colmap(outcome, img_filename, color):
    """
    Read the binary `outcome` map ("tp", "tn", "fp" or "fn") of an image from
    the model's eval-metrics folder and paint its pixels equal to 1 with
    `color`.

    Relies on the module-level `model_path` set in the main script.
    """
    outcome_map = imread(
        model_path
        / "eval-metrics/{}".format(outcome)
        / "{}_{}.png".format(Path(img_filename).stem, outcome)
    )
    # Single channel -> 3 identical channels so it can be recolored
    outcome_map = np.tile(np.expand_dims(outcome_map, axis=2), reps=(1, 1, 3))
    return map_color(outcome_map.astype(float), (1, 1, 1), color)


def plot_labels(ax, img, label, img_id, do_legend, alpha=0.5):
    """
    Plot the image with its labels overlaid

    Label pixels (255, 0, 0) are painted as cannot-be-flooded, (0, 0, 255) as
    must-be-flooded and (0, 0, 0) as may-be-flooded.

    Args:
        ax (matplotlib.axes.Axes): axes to draw on
        img (np.ndarray): image
        label (np.ndarray): color-coded label image
        img_id (str): identifier annotated on the plot
        do_legend (bool): whether to draw the legend
        alpha (float): transparency of the labels overlay
    """
    label_colmap = label.astype(float)
    label_colmap = map_color(label_colmap, (255, 0, 0), color_cannot)
    label_colmap = map_color(label_colmap, (0, 0, 255), color_must)
    label_colmap = map_color(label_colmap, (0, 0, 0), color_may)

    ax.imshow(img)
    ax.imshow(label_colmap, alpha=alpha)
    ax.axis("off")

    # Annotation
    _annotate_img_id(ax, img_id)

    # Legend
    if do_legend:
        _add_legend(
            ax,
            [
                (color_must, "Must-be-flooded"),
                (color_may, "May-be-flooded"),
                (color_cannot, "Cannot-be-flooded"),
            ],
            ncol=3,
        )


def plot_pred(ax, img, pred, img_id, do_legend, alpha=0.5):
    """
    Plot the image with the predicted mask overlaid

    Args:
        ax (matplotlib.axes.Axes): axes to draw on
        img (np.ndarray): image
        pred (np.ndarray): 2D prediction; pixels equal to 1 are highlighted
        img_id (str): identifier annotated on the plot
        do_legend (bool): whether to draw the legend
        alpha (float): transparency of the prediction overlay
    """
    pred = np.tile(np.expand_dims(pred, axis=2), reps=(1, 1, 3))

    pred_colmap = pred.astype(float)
    pred_colmap = map_color(pred_colmap, (1, 1, 1), color_pred)
    pred_colmap_ma = np.ma.masked_not_equal(pred_colmap, color_pred)
    # Fill the non-predicted pixels with the image itself
    pred_colmap_ma = pred_colmap_ma.mask * img + pred_colmap_ma

    ax.imshow(img)
    ax.imshow(pred_colmap_ma, alpha=alpha)
    ax.axis("off")

    # Annotation
    _annotate_img_id(ax, img_id)

    # Legend
    if do_legend:
        _add_legend(ax, [(color_pred, "Prediction")], ncol=3)


def plot_correct_incorrect(ax, img_filename, img, label, img_id, do_legend, alpha=0.5):
    """
    Plot the image with the TP / TN / FP / FN maps and the may-be-flooded
    area overlaid

    The four maps are read from the eval-metrics folder of the module-level
    `model_path`.

    Args:
        ax (matplotlib.axes.Axes): axes to draw on
        img_filename (str): file name of the image, used to locate the maps
        img (np.ndarray): image
        label (np.ndarray): color-coded label image
        img_id (str): identifier annotated on the plot
        do_legend (bool): whether to draw the legend
        alpha (float): transparency of the overlays
    """
    fp_map_colmap = _load_outcome_colmap("fp", img_filename, color_fp)
    fn_map_colmap = _load_outcome_colmap("fn", img_filename, color_fn)
    tp_map_colmap = _load_outcome_colmap("tp", img_filename, color_tp)
    tn_map_colmap = _load_outcome_colmap("tn", img_filename, color_tn)

    # May-be-flooded area, filled with the image elsewhere
    label_colmap = label.astype(float)
    label_colmap = map_color(label_colmap, (0, 0, 0), color_may)
    label_colmap_ma = np.ma.masked_not_equal(label_colmap, color_may)
    label_colmap_ma = label_colmap_ma.mask * img + label_colmap_ma

    # Combine masks
    maps = fp_map_colmap + fn_map_colmap + tp_map_colmap + tn_map_colmap
    maps_ma = np.ma.masked_equal(maps, (0, 0, 0))
    maps_ma = maps_ma.mask * img + maps_ma

    ax.imshow(img)
    ax.imshow(label_colmap_ma, alpha=alpha)
    ax.imshow(maps_ma, alpha=alpha)
    ax.axis("off")

    # Annotation
    _annotate_img_id(ax, img_id)

    # Legend
    if do_legend:
        _add_legend(
            ax,
            [
                (color_tp, "TP"),
                (color_tn, "TN"),
                (color_fp, "FP"),
                (color_fn, "FN"),
                (color_may, "May-be-flooded"),
            ],
            ncol=5,
        )


def plot_edge_coherence(
    ax, img, label, pred, img_id, do_legend, img_filename=None, alpha=0.5
):
    """
    Plot the edges of the prediction and of the must-be-flooded label, the
    areas in between, and white segments joining sampled prediction-edge
    pixels to their closest label-edge pixel

    Args:
        ax (matplotlib.axes.Axes): axes to draw on
        img (np.ndarray): image
        label (np.ndarray): color-coded label image
        pred (np.ndarray): 2D prediction
        img_id (str): identifier annotated on the plot
        do_legend (bool): whether to draw the legend
        img_filename (str, optional): file name of the image, used to locate
            the TP map. If None, falls back to the module-level `srs_sel`
            (previous behavior).
        alpha (float): transparency of the areas overlay
    """
    if img_filename is None:
        # Legacy behavior: the selected row is a global set by the main script
        img_filename = srs_sel.filename

    pred = np.tile(np.expand_dims(pred, axis=2), reps=(1, 1, 3))

    _, pred_ec, label_ec = edges_coherence_std_min(
        np.squeeze(pred[:, :, 0]), np.squeeze(encode_mask_label(label, "flood"))
    )

    ##################
    # Edge distances #
    ##################

    # Location of edges
    pred_ec_coord = np.argwhere(pred_ec > 0)
    label_ec_coord = np.argwhere(label_ec > 0)

    # Normalized pairwise distances between pred and label
    dist_mat = np.divide(
        euclidean_distances(pred_ec_coord, label_ec_coord), pred_ec.shape[0]
    )

    #############
    # Make plot #
    #############

    pred_ec = np.tile(
        np.expand_dims(np.asarray(pred_ec > 0, dtype=float), axis=2), reps=(1, 1, 3)
    )
    pred_ec_colmap = map_color(pred_ec, (1, 1, 1), color_pred)

    label_ec = np.tile(
        np.expand_dims(np.asarray(label_ec > 0, dtype=float), axis=2), reps=(1, 1, 3)
    )
    label_ec_colmap = map_color(label_ec, (1, 1, 1), color_must)

    # Combined pred and label edges
    combined_ec = pred_ec_colmap + label_ec_colmap
    combined_ec_ma = np.ma.masked_equal(combined_ec, (0, 0, 0))
    combined_ec_img = combined_ec_ma.mask * img + combined_ec

    # Pred
    pred_colmap = pred.astype(float)
    pred_colmap = map_color(pred_colmap, (1, 1, 1), color_pred)
    pred_colmap_ma = np.ma.masked_not_equal(pred_colmap, color_pred)

    # Must
    label_colmap = label.astype(float)
    label_colmap = map_color(label_colmap, (0, 0, 255), color_must)
    label_colmap_ma = np.ma.masked_not_equal(label_colmap, color_must)

    # TP
    tp_map_colmap = _load_outcome_colmap("tp", img_filename, color_tp)
    tp_map_colmap_ma = np.ma.masked_not_equal(tp_map_colmap, color_tp)

    # Combination
    comb_pred = (
        (pred_colmap_ma.mask ^ tp_map_colmap_ma.mask)
        & tp_map_colmap_ma.mask
        & combined_ec_ma.mask
    ) * pred_colmap
    comb_label = (
        (label_colmap_ma.mask ^ pred_colmap_ma.mask)
        & pred_colmap_ma.mask
        & combined_ec_ma.mask
    ) * label_colmap
    comb_tp = combined_ec_ma.mask * tp_map_colmap.copy()
    combined = comb_tp + comb_label + comb_pred
    combined_ma = np.ma.masked_equal(combined, (0, 0, 0))
    combined_ma = combined_ma.mask * combined_ec_img + combined_ma

    ax.imshow(combined_ec_img, alpha=1)
    ax.imshow(combined_ma, alpha=alpha)
    ax.axis("off")

    # Plot lines: one segment every `offset` prediction-edge pixels, taken in
    # order of increasing column
    idx_sort_x = np.argsort(pred_ec_coord[:, 1])
    offset = 100
    for idx in range(offset, pred_ec_coord.shape[0], offset):
        y0, x0 = pred_ec_coord[idx_sort_x[idx], :]
        argmin = np.argmin(dist_mat[idx_sort_x[idx]])
        y1, x1 = label_ec_coord[argmin, :]
        ax.plot([x0, x1], [y0, y1], color="white", linewidth=0.5)

    # Annotation
    _annotate_img_id(ax, img_id)

    # Legend
    if do_legend:
        _add_legend(
            ax,
            [
                (color_tp, "TP"),
                (color_pred, "Prediction"),
                (color_must, "Must-be-flooded"),
            ],
            ncol=3,
        )


def plot_images_metric(axes, metric, img_filename, img_id, do_legend, alpha=0.5):
    """
    Fill a row of 3 axes for one image: labels, prediction, and a
    metric-specific panel (correct/incorrect map for "error" and "f05", edge
    coherence for "edge_coherence")

    Reads the image, label and prediction from the module-level
    `imgs_orig_path`, `labels_path` and `model_path`.

    Args:
        axes (sequence of matplotlib.axes.Axes): the 3 axes of the row
        metric (str): one of "error", "f05", "edge_coherence"
        img_filename (str): file name of the image
        img_id (str): identifier annotated on the plots
        do_legend (bool): whether to draw the legends
        alpha (float): transparency of the overlays

    Raises:
        ValueError: if `metric` is not supported
    """
    # Read images
    img_path = imgs_orig_path / img_filename
    label_path = labels_path / "{}_labeled.png".format(Path(img_filename).stem)
    img, label = crop_and_resize(img_path, label_path)
    img = rgba2rgb(img) if img.shape[-1] == 4 else img / 255.0
    pred = imread(
        model_path / "eval-metrics/pred" / "{}_pred.png".format(Path(img_filename).stem)
    )

    # Label
    plot_labels(axes[0], img, label, img_id, do_legend, alpha=alpha)

    # Prediction
    plot_pred(axes[1], img, pred, img_id, do_legend, alpha=alpha)

    # Correct / incorrect
    if metric in ["error", "f05"]:
        plot_correct_incorrect(
            axes[2], img_filename, img, label, img_id, do_legend, alpha=alpha
        )
    # Edge coherence
    elif metric == "edge_coherence":
        plot_edge_coherence(
            axes[2],
            img,
            label,
            pred,
            img_id,
            do_legend,
            img_filename=img_filename,
            alpha=alpha,
        )
    else:
        raise ValueError("Unsupported metric: {}".format(metric))


def scatterplot_metrics_pair(ax, df, x_metric, y_metric, dict_images):
    """
    Scatter plot of `y_metric` against `x_metric` for all rows of `df`, with
    the images of `dict_images` annotated
    """
    sns.scatterplot(data=df, x=x_metric, y=y_metric, ax=ax)

    # Set X-label
    ax.set_xlabel(dict_metrics["names"][x_metric], rotation=0, fontsize="medium")

    # Set Y-label
    ax.set_ylabel(dict_metrics["names"][y_metric], rotation=90, fontsize="medium")

    # Change spines
    sns.despine(ax=ax, left=True, bottom=True)

    annotate_scatterplot(ax, dict_images, x_metric, y_metric)


def scatterplot_metrics(ax, df, dict_images):
    """
    Scatter plot of F05 against error, colored by edge coherence, with the
    images of `dict_images` annotated
    """
    sns.scatterplot(data=df, x="error", y="f05", hue="edge_coherence", ax=ax)

    # Set X-label
    ax.set_xlabel(dict_metrics["names"]["error"], rotation=0, fontsize="medium")

    # Set Y-label
    ax.set_ylabel(dict_metrics["names"]["f05"], rotation=90, fontsize="medium")

    annotate_scatterplot(ax, dict_images, "error", "f05")

    # Change spines
    sns.despine(ax=ax, left=True, bottom=True)

    # Set XY limits
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    ax.set_xlim([0.0, xlim[1]])
    ax.set_ylim([ylim[0], 1.0])


def annotate_scatterplot(ax, dict_images, x_metric, y_metric, offset=0.1):
    """
    Annotate, with an arrow and its identifier, each image of `dict_images`
    on a scatter plot

    The text is shifted by `offset` times the axis length, towards the
    center of the plot.

    Args:
        ax (matplotlib.axes.Axes): axes of the scatter plot
        dict_images (dict): maps image identifier -> mapping holding the
            values of `x_metric` and `y_metric`
        x_metric (str): key of the metric on the X axis
        y_metric (str): key of the metric on the Y axis
        offset (float): text shift, as a fraction of the axis length
    """
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    x_len = xlim[1] - xlim[0]
    y_len = ylim[1] - ylim[0]
    x_th = xlim[1] - x_len / 2.0
    y_th = ylim[1] - y_len / 2.0
    for text, d in dict_images.items():
        x = d[x_metric]
        y = d[y_metric]
        x_text = x + x_len * offset if x < x_th else x - x_len * offset
        y_text = y + y_len * offset if y < y_th else y - y_len * offset
        ax.annotate(
            xy=(x, y),
            xycoords="data",
            xytext=(x_text, y_text),
            textcoords="data",
            text=text,
            arrowprops=dict(facecolor="black", shrink=0.05),
            fontsize="medium",
            color="black",
        )


if __name__ == "__main__":
    # -----------------------------
    # -----  Parse arguments  -----
    # -----------------------------
    args = parsed_args()
    print("Args:\n" + "\n".join([f"    {k:20}: {v}" for k, v in vars(args).items()]))

    # Determine output dir
    if args.output_dir is None:
        output_dir = Path(os.environ["SLURM_TMPDIR"])
    else:
        output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Store args
    output_yml = output_dir / "labels.yml"
    with open(output_yml, "w") as f:
        yaml.dump(vars(args), f)

    # Data dirs
    imgs_orig_path = Path(args.masker_test_set_dir) / "imgs"
    labels_path = Path(args.masker_test_set_dir) / "labels"

    # Read CSV
    df = pd.read_csv(args.input_csv, index_col="model_img_idx")

    # Select best model
    df = df.loc[df.model_feats == args.best_model]
    if df.empty:
        raise ValueError(
            "No row of {} matches --best_model '{}'".format(
                args.input_csv, args.best_model
            )
        )
    v_key, model_dir = df.model.unique()[0].split("/")
    model_path = Path(args.models_log_path) / "ablation-{}".format(v_key) / model_dir

    # Set up plot
    sns.reset_orig()
    sns.set(style="whitegrid")
    plt.rcParams.update({"font.family": "serif"})
    plt.rcParams.update(
        {
            "font.serif": [
                "Computer Modern Roman",
                "Times New Roman",
                "Utopia",
                "New Century Schoolbook",
                "Century Schoolbook L",
                "ITC Bookman",
                "Bookman",
                "Times",
                "Palatino",
                "Charter",
                "serif",
                "Bitstream Vera Serif",
                "DejaVu Serif",
            ]
        }
    )

    # `is not None` so that a seed of 0 is honored too
    if args.seed is not None:
        np.random.seed(args.seed)

    # Number of top-ranked rows among which an example is drawn: at least 1
    # (small tables) and at most the table size (percentile > 1)
    n_candidates = min(len(df), max(1, int(args.percentile * len(df))))

    img_ids = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    dict_images = {}
    idx = 0
    for metric in metrics:

        fig, axes = plt.subplots(nrows=2, ncols=3, dpi=args.dpi, figsize=(18, 12))

        # "error" is best when lowest, the other metrics when highest.
        # Row 0 shows a best example (with legend), row 1 a worst example.
        best_ascending = metric == "error"
        for row, ascending in enumerate([best_ascending, not best_ascending]):
            idx_rand = np.random.permutation(n_candidates)[0]
            srs_sel = df.sort_values(by=metric, ascending=ascending).iloc[idx_rand]
            img_id = img_ids[idx]
            dict_images.update({img_id: srs_sel})

            # Read images
            img_filename = srs_sel.filename

            if not args.no_images:
                axes_row = axes[row, :]
                plot_images_metric(
                    axes_row,
                    metric,
                    img_filename,
                    img_id,
                    do_legend=(row == 0),
                    alpha=args.alpha,
                )

            idx += 1

        # Save figure
        output_fig = output_dir / "{}.png".format(metric)
        fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
        plt.close(fig)

    fig = plt.figure(dpi=args.dpi)

    scatterplot_metrics(fig.gca(), df, dict_images)

    output_fig = output_dir / "scatterplots.png"
    fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
    plt.close(fig)