Spaces:
Runtime error
Runtime error
""" | |
This script plots examples of the images that get the best and worst metrics
""" | |
print("Imports...", end="") | |
import os | |
import sys | |
from argparse import ArgumentParser | |
from pathlib import Path | |
import matplotlib.patches as mpatches | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import pandas as pd | |
import seaborn as sns | |
import yaml | |
from imageio import imread | |
from matplotlib.gridspec import GridSpec | |
from skimage.color import rgba2rgb | |
from sklearn.metrics.pairwise import euclidean_distances | |
sys.path.append("../") | |
from climategan.data import encode_mask_label | |
from climategan.eval_metrics import edges_coherence_std_min | |
from eval_masker import crop_and_resize | |
# -----------------------
# -----  Constants  -----
# -----------------------

# Metrics visualised by this script, one grid of the figure per metric
metrics = ["error", "f05", "edge_coherence"]

# Human-readable names of the metric columns (used as axis labels), and the
# subset of key metrics (same content as ``metrics``)
dict_metrics = dict(
    names={
        "tpr": "TPR, Recall, Sensitivity",
        "tnr": "TNR, Specificity, Selectivity",
        "fpr": "FPR",
        "fpt": "False positives relative to image size",
        "fnr": "FNR, Miss rate",
        "fnt": "False negatives relative to image size",
        "mpr": "May positive rate (MPR)",
        "mnr": "May negative rate (MNR)",
        "accuracy": "Accuracy (ignoring may)",
        "error": "Error",
        "f05": "F05 score",
        "precision": "Precision",
        "edge_coherence": "Edge coherence",
        "accuracy_must_may": "Accuracy (ignoring cannot)",
    },
    key_metrics=list(metrics),
)

# Colors
# Label and prediction colours, picked from seaborn's colour-blind palette
colorblind_palette = sns.color_palette("colorblind")
color_cannot, color_must, color_may, color_pred = (
    colorblind_palette[i] for i in (1, 2, 7, 4)
)

# Confusion-map colours (TP / TN / FP / FN), picked from a 5-colour icefire palette
icefire = sns.color_palette("icefire", as_cmap=False, n_colors=5)
color_tp, color_tn, color_fp, color_fn = (icefire[i] for i in (0, 1, 4, 3))
def parsed_args():
    """
    Parse and return the command-line arguments.

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--input_csv",
        default="ablations_metrics_20210311.csv",
        type=str,
        help="CSV containing the results of the ablation study",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="Output directory (defaults to $SLURM_TMPDIR)",
    )
    parser.add_argument(
        "--models_log_path",
        default=None,
        type=str,
        help="Path containing the log files of the models",
    )
    parser.add_argument(
        "--masker_test_set_dir",
        default=None,
        type=str,
        help="Directory containing the test images",
    )
    parser.add_argument(
        "--best_model",
        default="dada, msd_spade, pseudo",
        type=str,
        help="The string identifier of the best model",
    )
    parser.add_argument(
        "--dpi",
        default=200,
        type=int,
        help="DPI for the output images",
    )
    # NOTE(review): not read by the code visible in this file; the plotting
    # helpers use a fixed alpha of 0.5
    parser.add_argument(
        "--alpha",
        default=0.5,
        type=float,
        help="Transparency of labels shade",
    )
    parser.add_argument(
        "--percentile",
        default=0.05,
        type=float,
        help="Fraction of the best (resp. worst) ranked images from which "
        "each example image is randomly sampled",
    )
    parser.add_argument(
        "--seed",
        default=None,
        type=int,
        help="Random seed for sampling the example images, for reproducibility",
    )
    parser.add_argument(
        "--no_images",
        action="store_true",
        default=False,
        help="Do not generate images",
    )
    return parser.parse_args()
def map_color(arr, input_color, output_color, rtol=1e-09):
    """
    Return a copy of ``arr`` where every pixel matching ``input_color`` is
    replaced by ``output_color``.

    A pixel matches when all of its channels are close to ``input_color``
    according to ``np.isclose`` with relative tolerance ``rtol``.

    Args:
        arr (np.ndarray): image with colour channels on axis 2
            (assumed (H, W, C) with C == len(input_color))
        input_color (sequence): colour to replace
        output_color (sequence): replacement colour
        rtol (float): relative tolerance forwarded to ``np.isclose``

    Returns:
        np.ndarray: the recoloured copy; ``arr`` itself is left untouched
    """
    reps = tuple(arr.shape[:2]) + (1,)
    reference = np.tile(input_color, reps)
    is_match = np.isclose(arr, reference, rtol=rtol).all(axis=2)
    recolored = arr.copy()
    recolored[is_match] = output_color
    return recolored
def plot_labels(ax, img, label, img_id, n_, add_title, do_legend):
    """
    Draw ``img`` overlaid with its colour-coded label on ``ax``.

    Label colours (255, 0, 0), (0, 0, 255) and (0, 0, 0) are remapped to the
    cannot / must / may colours respectively. The image identifier is written
    to the left of the axes, in green for rows 1, 3 and 5 and red otherwise.

    Args:
        ax: matplotlib axes to draw on
        img: image shown underneath the labels
        label: label image (assumed (H, W, 3) with 0-255 colours — TODO confirm)
        img_id (str): identifier written next to the image
        n_ (int): row counter deciding the identifier colour
        add_title (bool): whether to set the column title "Labels"
        do_legend (bool): unused
    """
    recolored = label.astype(float)
    for source_color, target_color in (
        ((255, 0, 0), color_cannot),
        ((0, 0, 255), color_must),
        ((0, 0, 0), color_may),
    ):
        recolored = map_color(recolored, source_color, target_color)

    ax.imshow(img)
    ax.imshow(recolored, alpha=0.5)
    ax.axis("off")

    id_color = "green" if n_ in [1, 3, 5] else "red"
    ax.text(
        -0.15,
        0.5,
        img_id,
        color=id_color,
        fontweight="roman",
        fontsize="x-large",
        horizontalalignment="left",
        verticalalignment="center",
        transform=ax.transAxes,
    )

    if add_title:
        ax.set_title("Labels", rotation=0, fontsize="x-large")
def plot_pred(ax, img, pred, img_id, add_title, do_legend):
    """
    Draw ``img`` with the predicted mask highlighted in ``color_pred``.

    Args:
        ax: matplotlib axes to draw on
        img: image shown underneath the prediction
        pred: 2-D prediction mask; after tiling to three channels, pixels
            equal to (1, 1, 1) are recoloured with ``color_pred``
        img_id: unused
        add_title (bool): whether to set the column title "Prediction"
        do_legend (bool): unused
    """
    pred_rgb = np.tile(np.expand_dims(pred, axis=2), reps=(1, 1, 3)).astype(float)
    highlighted = map_color(pred_rgb, (1, 1, 1), color_pred)

    # Mask every value that differs from the prediction colour so that only
    # the predicted pixels are drawn on top of the image
    overlay = np.ma.masked_not_equal(highlighted, color_pred)
    overlay = overlay.mask * img + overlay

    ax.imshow(img)
    ax.imshow(overlay, alpha=0.5)
    ax.axis("off")

    if add_title:
        ax.set_title("Prediction", rotation=0, fontsize="x-large")
def _load_confusion_map(img_filename, kind, color):
    """
    Read one saved confusion map and paint its (1, 1, 1) pixels ``color``.

    Args:
        img_filename (str): name of the evaluated image
        kind (str): "fp", "fn", "tp" or "tn"; selects the file
            ``model_path / "eval-metrics" / kind / "<stem>_<kind>.png"``
        color: RGB colour used for the pixels of the map

    Returns:
        np.ndarray: 3-channel float image, zero outside the map

    NOTE: reads the module-level ``model_path`` set in ``__main__``.
    """
    stem = Path(img_filename).stem
    binary_map = imread(model_path / "eval-metrics" / kind / "{}_{}.png".format(stem, kind))
    binary_map = np.tile(np.expand_dims(binary_map, axis=2), reps=(1, 1, 3))
    return map_color(binary_map.astype(float), (1, 1, 1), color)


def plot_correct_incorrect(
    ax, img_filename, img, metric, label, img_id, n_, add_title, do_legend
):
    """
    Draw the TP / TN / FP / FN maps of ``img_filename`` over ``img`` on ``ax``,
    together with the may-be-flooded area of the label.

    Args:
        ax: matplotlib axes to draw on
        img_filename (str): name of the image, used to locate its saved maps
        img: image shown underneath the maps
        metric: unused
        label: label image; pixels equal to (0, 0, 0) are drawn as "may"
        img_id, n_, do_legend: unused
        add_title (bool): whether to set the column title "Metric"
    """
    # Colour-coded confusion maps, summed in the fp, fn, tp, tn order
    maps = sum(
        _load_confusion_map(img_filename, kind, color)
        for kind, color in (
            ("fp", color_fp),
            ("fn", color_fn),
            ("tp", color_tp),
            ("tn", color_tn),
        )
    )

    # May-be-flooded area of the label
    label_colmap = map_color(label.astype(float), (0, 0, 0), color_may)
    label_colmap_ma = np.ma.masked_not_equal(label_colmap, color_may)
    label_colmap_ma = label_colmap_ma.mask * img + label_colmap_ma

    # Hide the pixels that belong to none of the maps
    maps_ma = np.ma.masked_equal(maps, (0, 0, 0))
    maps_ma = maps_ma.mask * img + maps_ma

    ax.imshow(img)
    ax.imshow(label_colmap_ma, alpha=0.5)
    ax.imshow(maps_ma, alpha=0.5)
    ax.axis("off")

    if add_title:
        ax.set_title("Metric", rotation=0, fontsize="x-large")
def plot_edge_coherence(
    ax,
    img,
    metric,
    label,
    pred,
    img_id,
    n_,
    add_title,
    do_legend,
    img_filename=None,
):
    """
    Visualise the edge-coherence metric of one image on ``ax``.

    The prediction edges and the flood-label edges returned by
    ``edges_coherence_std_min`` are drawn over ``img`` together with the
    prediction, the must-be-flooded label and the TP map. Every 100th
    prediction edge pixel (sorted by column) is linked by a white line to its
    closest label edge pixel.

    Args:
        ax: matplotlib axes to draw on
        img: image shown underneath
        metric: unused
        label: label image; (0, 0, 255) pixels are drawn as "must"
        pred: 2-D prediction mask
        img_id, n_, do_legend: unused
        add_title (bool): whether to set the column title "Metric"
        img_filename (str, optional): name of the image, used to locate its
            TP map. When omitted, falls back to ``srs_sel.filename`` from the
            ``__main__`` loop, for backward compatibility.
    """
    if img_filename is None:
        # Legacy callers do not pass the file name; it was read from the
        # script-level ``srs_sel`` global
        img_filename = srs_sel.filename

    pred = np.tile(np.expand_dims(pred, axis=2), reps=(1, 1, 3))
    _, pred_ec, label_ec = edges_coherence_std_min(
        np.squeeze(pred[:, :, 0]), np.squeeze(encode_mask_label(label, "flood"))
    )

    ##################
    # Edge distances #
    ##################
    # Location of edges
    pred_ec_coord = np.argwhere(pred_ec > 0)
    label_ec_coord = np.argwhere(label_ec > 0)
    # euclidean_distances rejects empty inputs: only compute distances (and
    # draw the pred -> label lines) when both edge sets are non-empty
    has_edges = pred_ec_coord.shape[0] > 0 and label_ec_coord.shape[0] > 0
    if has_edges:
        # Pairwise distances between pred and label edge pixels, normalised
        # by the number of rows of the edge map
        dist_mat = np.divide(
            euclidean_distances(pred_ec_coord, label_ec_coord), pred_ec.shape[0]
        )

    #############
    # Make plot #
    #############
    pred_ec = np.tile(
        np.expand_dims(np.asarray(pred_ec > 0, dtype=float), axis=2), reps=(1, 1, 3)
    )
    pred_ec_colmap = map_color(pred_ec, (1, 1, 1), color_pred)
    label_ec = np.tile(
        np.expand_dims(np.asarray(label_ec > 0, dtype=float), axis=2), reps=(1, 1, 3)
    )
    label_ec_colmap = map_color(label_ec, (1, 1, 1), color_must)

    # Combined pred and label edges drawn over the image
    combined_ec = pred_ec_colmap + label_ec_colmap
    combined_ec_ma = np.ma.masked_equal(combined_ec, (0, 0, 0))
    combined_ec_img = combined_ec_ma.mask * img + combined_ec

    # Pred
    pred_colmap = map_color(pred.astype(float), (1, 1, 1), color_pred)
    pred_colmap_ma = np.ma.masked_not_equal(pred_colmap, color_pred)

    # Must
    label_colmap = map_color(label.astype(float), (0, 0, 255), color_must)
    label_colmap_ma = np.ma.masked_not_equal(label_colmap, color_must)

    # TP
    tp_map = imread(
        model_path / "eval-metrics/tp" / "{}_tp.png".format(Path(img_filename).stem)
    )
    tp_map = np.tile(np.expand_dims(tp_map, axis=2), reps=(1, 1, 3))
    tp_map_colmap = map_color(tp_map.astype(float), (1, 1, 1), color_tp)
    tp_map_colmap_ma = np.ma.masked_not_equal(tp_map_colmap, color_tp)

    # Combination: prediction outside the TP map, label outside the
    # prediction, and the TP map, all restricted to non-edge pixels
    comb_pred = (
        (pred_colmap_ma.mask ^ tp_map_colmap_ma.mask)
        & tp_map_colmap_ma.mask
        & combined_ec_ma.mask
    ) * pred_colmap
    comb_label = (
        (label_colmap_ma.mask ^ pred_colmap_ma.mask)
        & pred_colmap_ma.mask
        & combined_ec_ma.mask
    ) * label_colmap
    comb_tp = combined_ec_ma.mask * tp_map_colmap.copy()
    combined = comb_tp + comb_label + comb_pred
    combined_ma = np.ma.masked_equal(combined, (0, 0, 0))
    combined_ma = combined_ma.mask * combined_ec_img + combined_ma

    ax.imshow(combined_ec_img, alpha=1)
    ax.imshow(combined_ma, alpha=0.5)
    ax.axis("off")

    # Plot lines from every ``offset``-th prediction edge pixel (sorted by
    # column) to its closest label edge pixel
    if has_edges:
        idx_sort_x = np.argsort(pred_ec_coord[:, 1])
        offset = 100
        for idx in range(offset, pred_ec_coord.shape[0], offset):
            pred_idx = idx_sort_x[idx]
            y0, x0 = pred_ec_coord[pred_idx, :]
            y1, x1 = label_ec_coord[np.argmin(dist_mat[pred_idx]), :]
            ax.plot([x0, x1], [y0, y1], color="white", linewidth=0.5)

    if add_title:
        ax.set_title("Metric", rotation=0, fontsize="x-large")
def plot_images_metric(
    axes, metric, img_filename, img_id, n_, srs_sel, add_title, do_legend
):
    """
    Fill one row of four axes for one image: labels, prediction, a
    metric-specific visualisation and a legend.

    Args:
        axes: sequence of four matplotlib axes
        metric (str): "error", "f05" or "edge_coherence"
        img_filename (str): name of the image inside ``imgs_orig_path``
        img_id (str): identifier written next to the row
        n_ (int): row counter; rows 1, 3 and 5 are treated as the "good"
            examples when choosing the legend title
        srs_sel (pd.Series): CSV row of the image; its ``error``, ``f05`` and
            ``edge_coherence`` values are printed in the legend
        add_title (bool): whether to set the column titles
        do_legend (bool): forwarded to the label/prediction plotters

    Raises:
        ValueError: if ``metric`` is not supported

    NOTE: reads the module-level ``imgs_orig_path``, ``labels_path`` and
    ``model_path`` set in ``__main__``.
    """
    # Validate before doing any I/O
    if metric not in ("error", "f05", "edge_coherence"):
        raise ValueError("Unsupported metric: {!r}".format(metric))

    is_good_row = n_ in [1, 3, 5]
    stem = Path(img_filename).stem

    # Read images
    img_path = imgs_orig_path / img_filename
    label_path = labels_path / "{}_labeled.png".format(stem)
    img, label = crop_and_resize(img_path, label_path)
    img = rgba2rgb(img) if img.shape[-1] == 4 else img / 255.0
    pred = imread(model_path / "eval-metrics/pred" / "{}_pred.png".format(stem))

    # Label
    plot_labels(axes[0], img, label, img_id, n_, add_title, do_legend)
    # Prediction
    plot_pred(axes[1], img, pred, img_id, add_title, do_legend)

    if metric in ("error", "f05"):
        # Correct / incorrect
        plot_correct_incorrect(
            axes[2],
            img_filename,
            img,
            metric,
            label,
            img_id,
            n_,
            add_title,
            do_legend=False,
        )
        legend_entries = [
            (color_tp, "TP"),
            (color_tn, "TN"),
            (color_fp, "FP"),
            (color_fn, "FN"),
            (color_may, "May-be-flooded"),
        ]
        if metric == "error":
            title = "Low error rate" if is_good_row else "High error rate"
        else:
            title = "High F05 score" if is_good_row else "Low F05 score"
    else:
        # Edge coherence
        plot_edge_coherence(
            axes[2], img, metric, label, pred, img_id, n_, add_title, do_legend=False
        )
        legend_entries = [
            (color_tp, "TP"),
            (color_pred, "Prediction"),
            (color_must, "Must-be-flooded"),
        ]
        title = "High edge coherence" if is_good_row else "Low edge coherence"

    handles = [
        mpatches.Patch(facecolor=color, label=text, linewidth=1.0, alpha=0.66)
        for color, text in legend_entries
    ]
    labels = [text for _, text in legend_entries]

    labels_values_title = "Error: {:.4f} \nF05: {:.4f} \nEdge coherence: {:.4f}".format(
        srs_sel.error, srs_sel.f05, srs_sel.edge_coherence
    )
    plot_legend(axes[3], img, handles, labels, labels_values_title, title)
def plot_legend(ax, img, handles, labels, labels_values_title, title):
    """
    Turn ``ax`` into a legend panel.

    A white image with the shape of ``img`` fills the axes. The colour legend
    (``handles`` / ``labels`` under ``title``) is anchored top-left, and a
    second legend with no entries, whose title holds ``labels_values_title``,
    is anchored bottom-left.

    Args:
        ax: matplotlib axes to draw on
        img: image whose shape is used for the white background
        handles: legend patches
        labels (list[str]): text of each legend entry
        labels_values_title (str): text shown in the bottom-left legend
        title (str): title of the colour legend
    """
    background = np.zeros_like(img, dtype=np.uint8)
    background[...] = 255
    ax.imshow(background)
    ax.axis("off")

    colour_legend = ax.legend(
        handles=handles,
        labels=labels,
        title=title,
        title_fontsize="medium",
        labelspacing=0.6,
        loc="upper left",
        fontsize="x-small",
        frameon=False,
    )
    colour_legend._legend_box.align = "left"

    values_legend = ax.legend(
        title=labels_values_title,
        title_fontsize="small",
        loc="lower left",
        frameon=False,
    )
    values_legend._legend_box.align = "left"

    # The second ax.legend() call replaced the first one on the axes, so the
    # colour legend is attached again as an extra artist
    ax.add_artist(colour_legend)
def scatterplot_metrics_pair(ax, df, x_metric, y_metric, dict_images):
    """
    Scatter ``y_metric`` against ``x_metric`` for every row of ``df``.

    Both axes are labelled with the human-readable metric names, the left
    and bottom spines are removed, and the example images of
    ``dict_images`` are annotated.
    """
    display_names = dict_metrics["names"]

    sns.scatterplot(data=df, x=x_metric, y=y_metric, ax=ax)
    ax.set_xlabel(display_names[x_metric], rotation=0, fontsize="medium")
    ax.set_ylabel(display_names[y_metric], rotation=90, fontsize="medium")
    sns.despine(ax=ax, left=True, bottom=True)

    annotate_scatterplot(ax, dict_images, x_metric, y_metric)
def scatterplot_metrics(ax, df, df_all, dict_images, plot_all=False):
    """
    Scatter F05 against error for the best model, coloured by edge coherence.

    Args:
        ax: matplotlib axes to draw on
        df (pd.DataFrame): per-image metrics of the best model
        df_all (pd.DataFrame): per-image metrics of all models; only used
            when ``plot_all`` is True
        dict_images (dict): example images keyed by id, annotated on the plot
        plot_all (bool): also draw the other models in the background:
            rows with ``ground`` as "+", rows with ``instagan`` as "x", and
            the remaining rows (excluding ``df``'s model) as squares
    """
    if plot_all:
        # Rows that do not belong to the model plotted in the foreground
        not_best = ~df_all.model_feats.isin(df.model_feats.unique())
        background = [
            (df_all.ground == True, "+"),  # noqa: E712
            (df_all.instagan == True, "x"),  # noqa: E712
            # Neither ground nor instagan (the ground test used to be a
            # duplicated instagan test)
            (
                (df_all.ground == False)  # noqa: E712
                & (df_all.instagan == False)  # noqa: E712
                & not_best,
                "s",
            ),
        ]
        for selection, marker in background:
            sns.scatterplot(
                data=df_all.loc[selection],
                x="error",
                y="f05",
                hue="edge_coherence",
                ax=ax,
                marker=marker,
                alpha=0.25,
            )

    # Best model
    cmap_ = sns.cubehelix_palette(as_cmap=True)
    sns.scatterplot(
        data=df, x="error", y="f05", hue="edge_coherence", ax=ax, palette=cmap_
    )
    norm = plt.Normalize(df["edge_coherence"].min(), df["edge_coherence"].max())
    sm = plt.cm.ScalarMappable(cmap=cmap_, norm=norm)
    sm.set_array([])

    # Replace the hue legend by a colorbar. ``sm`` is not attached to any
    # Axes, so the Axes the colorbar takes space from is passed explicitly
    ax.get_legend().remove()
    ax_cbar = ax.figure.colorbar(sm, ax=ax)
    ax_cbar.set_label("Edge coherence", labelpad=8)

    # Set X-label
    ax.set_xlabel(dict_metrics["names"]["error"], rotation=0, fontsize="medium")
    # Set Y-label
    ax.set_ylabel(dict_metrics["names"]["f05"], rotation=90, fontsize="medium")

    annotate_scatterplot(ax, dict_images, "error", "f05")

    # Change spines
    sns.despine(ax=ax, left=True, bottom=True)

    # Set XY limits: error starts at 0, F05 ends at 1
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    ax.set_xlim([0.0, xlim[1]])
    ax.set_ylim([ylim[0], 1.0])
def annotate_scatterplot(ax, dict_images, x_metric, y_metric, offset=0.1):
    """
    Add arrow annotations for the example images on a scatter plot.

    Images "B", "D" and "F" are annotated one by one. "A", "C" and "E" share
    a single "A, C, E" annotation pointing at their mean position. Each text
    is shifted from its point towards the centre of the current axis limits
    by ``offset`` times the axis span (twice that horizontally and 0.45 times
    that vertically for the shared annotation).

    Args:
        ax: matplotlib axes holding the scatter plot
        dict_images (dict): image id -> row holding the metric values; when
            "A" is present, "C" and "E" must be present too
        x_metric (str): metric plotted on the x axis
        y_metric (str): metric plotted on the y axis
        offset (float): text shift, as a fraction of the axis span
    """
    (x_lo, x_hi), (y_lo, y_hi) = ax.get_xlim(), ax.get_ylim()
    x_span = x_hi - x_lo
    y_span = y_hi - y_lo
    x_mid = x_hi - x_span / 2.0
    y_mid = y_hi - y_span / 2.0

    def _towards_centre(value, mid, shift):
        # Move the text towards the middle of the axis
        return value + shift if value < mid else value - shift

    def _annotate(x, y, x_shift, y_shift, text):
        ax.annotate(
            xy=(x, y),
            xycoords="data",
            xytext=(
                _towards_centre(x, x_mid, x_shift),
                _towards_centre(y, y_mid, y_shift),
            ),
            textcoords="data",
            text=text,
            arrowprops=dict(facecolor="black", shrink=0.05),
            fontsize="medium",
            color="black",
        )

    for key, row in dict_images.items():
        if key in ["B", "D", "F"]:
            _annotate(
                row[x_metric],
                row[y_metric],
                x_span * offset,
                y_span * offset,
                key,
            )
        elif key == "A":
            first, third, fifth = (dict_images[k] for k in ("A", "C", "E"))
            x_mean = (first[x_metric] + third[x_metric] + fifth[x_metric]) / 3
            y_mean = (first[y_metric] + third[y_metric] + fifth[y_metric]) / 3
            _annotate(
                x_mean,
                y_mean,
                x_span * 2 * offset,
                y_span * 0.45 * offset,
                "A, C, E",
            )
if __name__ == "__main__":
    # -----------------------------
    # -----  Parse arguments  -----
    # -----------------------------
    args = parsed_args()
    print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))

    # Determine output dir ($SLURM_TMPDIR unless --output_dir is given)
    if args.output_dir is None:
        output_dir = Path(os.environ["SLURM_TMPDIR"])
    else:
        output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Store args
    output_yml = output_dir / "labels.yml"
    with open(output_yml, "w") as f:
        yaml.dump(vars(args), f)

    # Data dirs (module-level globals read by plot_images_metric)
    imgs_orig_path = Path(args.masker_test_set_dir) / "imgs"
    labels_path = Path(args.masker_test_set_dir) / "labels"

    # Read CSV
    df_all = pd.read_csv(args.input_csv, index_col="model_img_idx")

    # Select best model
    df = df_all.loc[df_all.model_feats == args.best_model]
    if df.empty:
        raise ValueError(
            "No rows with model_feats == {!r} in {}".format(
                args.best_model, args.input_csv
            )
        )
    v_key, model_dir = df.model.unique()[0].split("/")
    # Module-level global read by the plotting helpers
    model_path = Path(args.models_log_path) / "ablation-{}".format(v_key) / model_dir

    # Set up plot
    sns.reset_orig()
    sns.set(style="whitegrid")
    plt.rcParams.update({"font.family": "serif"})
    plt.rcParams.update(
        {
            "font.serif": [
                "Computer Modern Roman",
                "Times New Roman",
                "Utopia",
                "New Century Schoolbook",
                "Century Schoolbook L",
                "ITC Bookman",
                "Bookman",
                "Times",
                "Palatino",
                "Charter",
                "Bitstream Vera Serif",
                "DejaVu Serif",
                # Generic family last, as a catch-all fallback
                "serif",
            ]
        }
    )

    # 0 is a valid seed: compare against None rather than truthiness
    if args.seed is not None:
        np.random.seed(args.seed)

    img_ids = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    dict_images = {}
    idx = 0

    # Define grid of subplots
    grid_vmargin = 0.03  # Extent of the vertical margin between metric grids
    ax_hspace = 0.04  # Extent of the vertical space between axes of same grid
    ax_wspace = 0.05  # Extent of the horizontal space between axes of same grid
    n_grids = len(metrics)
    n_cols = 4
    n_rows = 2
    h_grid = (1.0 / n_grids) - ((n_grids - 1) * grid_vmargin) / n_grids

    fig1 = plt.figure(dpi=args.dpi, figsize=(11, 13))

    n_ = 0
    for metric_id, metric in enumerate(metrics):
        # Create grid
        top_grid = 1.0 - metric_id * h_grid - metric_id * grid_vmargin
        bottom_grid = top_grid - h_grid
        gridspec = GridSpec(
            n_rows,
            n_cols,
            wspace=ax_wspace,
            hspace=ax_hspace,
            bottom=bottom_grid,
            top=top_grid,
        )

        # Examples are drawn at random among the top ``percentile`` fraction
        # of the ranking; at least one candidate so small CSVs do not crash
        n_candidates = max(1, int(args.percentile * len(df)))

        # Row 0 shows a "best" example, row 1 a "worst" one. Lower is better
        # for the error, higher is better for the other metrics.
        for row, best in enumerate((True, False)):
            ascending = (metric == "error") == best
            idx_rand = np.random.permutation(n_candidates)[0]
            # ``srs_sel`` stays a module-level global: plot_edge_coherence
            # falls back to it to locate the TP map
            srs_sel = df.sort_values(by=metric, ascending=ascending).iloc[idx_rand]
            img_id = img_ids[idx]
            dict_images.update({img_id: srs_sel})

            # Read images
            img_filename = srs_sel.filename
            axes_row = [fig1.add_subplot(gridspec[row, c]) for c in range(n_cols)]
            if not args.no_images:
                n_ += 1
                # Column titles only on the very first row of the figure
                add_title = metric_id == 0 and row == 0
                plot_images_metric(
                    axes_row,
                    metric,
                    img_filename,
                    img_id,
                    n_,
                    srs_sel,
                    add_title=add_title,
                    do_legend=False,
                )
            idx += 1
            print("1 more row done.")

    output_fig = output_dir / "all_metrics.png"
    fig1.tight_layout()
    fig1.savefig(output_fig, dpi=fig1.dpi, bbox_inches="tight")

    # Scatter plot (scatterplot_metrics_pair can instead plot each pair of
    # metrics on its own axes)
    fig2 = plt.figure(dpi=args.dpi)
    scatterplot_metrics(fig2.gca(), df, df_all, dict_images)
    output_fig = output_dir / "scatterplots.png"
    fig2.savefig(output_fig, dpi=fig2.dpi, bbox_inches="tight")