# climateGAN / eval_masker.py
# vict0rsch's picture
# initial commit from cc-ai/climateGAN
# 448ebbd
# raw
# history blame
# 26 kB
"""
Compute metrics of the performance of the masker using a set of ground-truth labels
run eval_masker.py --model "/miniscratch/_groups/ccai/checkpoints/model/"
"""
print("Imports...", end="")
import os
import os.path
from argparse import ArgumentParser
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from comet_ml import Experiment
import torch
import yaml
from skimage.color import rgba2rgb
from skimage.io import imread, imsave
from skimage.transform import resize
from skimage.util import img_as_ubyte
from torchvision.transforms import ToTensor
from climategan.data import encode_mask_label
from climategan.eval_metrics import (
masker_classification_metrics,
get_confusion_matrix,
edges_coherence_std_min,
boxplot_metric,
clustermap_metric,
)
from climategan.transforms import PrepareTest
from climategan.trainer import Trainer
from climategan.utils import find_images
# Static description of the metrics handled by this script.
dict_metrics = {
    # Human-readable label of each metric; this mapping is what gets handed to
    # the boxplot / clustermap helpers as `dict_metrics`.
    "names": {
        "tpr": "TPR, Recall, Sensitivity",
        "tnr": "TNR, Specificity, Selectivity",
        "fpr": "FPR",
        "fpt": "False positives relative to image size",
        "fnr": "FNR, Miss rate",
        "fnt": "False negatives relative to image size",
        "mpr": "May positive rate (MPR)",
        "mnr": "May negative rate (MNR)",
        "accuracy": "Accuracy (ignoring may)",
        "error": "Error (ignoring may)",
        # The key is "f05", i.e. the F-beta score with beta = 0.5; the label
        # used to read "F0.05".
        "f05": "F0.5 score",
        "precision": "Precision",
        "edge_coherence": "Edge coherence",
        "accuracy_must_may": "Accuracy (ignoring cannot)",
    },
    # Per-metric reference values. In the code visible here only the "tpr",
    # "fpr" and "edge_coherence" entries are read (to flag poorly handled
    # images); "mpr" and "mnr" have no threshold.
    "threshold": {
        "tpr": 0.95,
        "tnr": 0.95,
        "fpr": 0.05,
        "fpt": 0.01,
        "fnr": 0.05,
        "fnt": 0.01,
        "accuracy": 0.95,
        "error": 0.05,
        "f05": 0.95,
        "precision": 0.95,
        "edge_coherence": 0.02,
        "accuracy_must_may": 0.5,
    },
    # Short list of headline metrics (not read by the code visible here).
    "key_metrics": ["f05", "error", "edge_coherence", "mnr"],
}
print("Ok.")
def parsed_args():
    """Parse and return the command-line arguments

    The arguments are read from ``sys.argv`` by ``ArgumentParser.parse_args``.

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = ArgumentParser()

    # ----- Model and data locations -----
    parser.add_argument(
        "--model",
        type=str,
        help="Path to a pre-trained model",
    )
    parser.add_argument(
        "--images_dir",
        default="/miniscratch/_groups/ccai/data/omnigan/masker-test-set/imgs",
        type=str,
        help="Directory containing the original test images",
    )
    parser.add_argument(
        "--labels_dir",
        default="/miniscratch/_groups/ccai/data/omnigan/masker-test-set/labels",
        type=str,
        help="Directory containing the labeled images",
    )

    # ----- Pre-processing and inference options -----
    parser.add_argument(
        "--image_size",
        default=640,
        type=int,
        # fixed typo: the size is a height and a width, not a "weight"
        help="The height and width of the pre-processed images",
    )
    parser.add_argument(
        "--max_files",
        default=-1,
        type=int,
        help="Limit loaded samples",
    )
    parser.add_argument(
        "--bin_value", default=0.5, type=float, help="Mask binarization threshold"
    )
    parser.add_argument(
        "-y",
        "--yaml",
        default=None,
        type=str,
        help="load a yaml file to parametrize the evaluation",
    )

    # ----- Logging and outputs -----
    parser.add_argument(
        "-t", "--tags", nargs="*", help="Comet.ml tags", default=[], type=str
    )
    parser.add_argument(
        "-p",
        "--plot",
        action="store_true",
        default=False,
        help="Plot masker images & their metrics overlays",
    )
    parser.add_argument(
        "--no_paint",
        action="store_true",
        default=False,
        help="Do not log painted images",
    )
    parser.add_argument(
        "--write_metrics",
        action="store_true",
        default=False,
        help="If True, write CSV file and maps images in model's path directory",
    )
    parser.add_argument(
        "--load_metrics",
        action="store_true",
        default=False,
        help="If True, load predictions and metrics instead of re-computing",
    )
    parser.add_argument(
        "--prepare_torch",
        action="store_true",
        default=False,
        help="If True, pre-process images as torch tensors",
    )
    parser.add_argument(
        "--output_csv",
        default=None,
        type=str,
        help="Filename of the output CSV with the metrics of all models",
    )

    return parser.parse_args()
def uint8(array):
    """Cast an array to unsigned 8-bit integers

    Args:
        array (np.ndarray): the array to convert

    Returns:
        np.ndarray: the result of ``array.astype`` with the uint8 dtype
    """
    as_bytes = array.astype("uint8")
    return as_bytes
def crop_and_resize(image_path, label_path, size=640):
    """
    Resizes an image so that it keeps the aspect ratio and its smallest dimension
    is `size`, then crops this resized image in its center so that the output is
    `size` x `size` without aspect ratio distortion. The label is resized to the
    same target shape as the image and cropped with the same window.

    A warning is printed when the image and the label do not have the same
    height and width.

    Args:
        image_path (Path or str): Path to an image
        label_path (Path or str): Path to the image's associated label
        size (int, optional): Side of the square output. Defaults to 640.

    Returns:
        tuple((np.ndarray, np.ndarray)): (new image, new label)
    """
    img = imread(image_path)
    lab = imread(label_path)

    # TODO: remove (debug)
    if img.shape[:2] != lab.shape[:2]:
        # Path(...) so that plain-string paths do not break the warning
        print(
            "\nWARNING: shape mismatch: im -> ({}) {}, lab -> ({}) {}".format(
                img.shape[:2],
                Path(image_path).name,
                lab.shape[:2],
                Path(label_path).name,
            )
        )

    # resize keeping aspect ratio: smallest dim is `size`
    height, width = img.shape[:2]
    if height < width:
        target_shape = (size, int(size * width / height))
    else:
        target_shape = (int(size * height / width), size)

    r_img = uint8(resize(img, target_shape, preserve_range=True, anti_aliasing=True))
    # The label uses the image's target shape, so both arrays share the crop
    # window computed below. order=0 without anti-aliasing: no interpolation
    # between label values.
    r_lab = uint8(
        resize(lab, target_shape, preserve_range=True, anti_aliasing=False, order=0)
    )

    # crop in the center
    resized_h, resized_w = r_img.shape[:2]
    top = (resized_h - size) // 2
    left = (resized_w - size) // 2
    rows = slice(top, top + size)
    cols = slice(left, left + size)

    rc_img = r_img[rows, cols, :]
    # works for both (H, W) and (H, W, C) labels
    rc_lab = r_lab[rows, cols]

    return rc_img, rc_lab
def plot_images(
    output_filename,
    img,
    label,
    pred,
    metrics_dict,
    maps_dict,
    edge_coherence=-1,
    pred_edge=None,
    label_edge=None,
    dpi=300,
    alpha=0.5,
    vmin=0.0,
    vmax=1.0,
    fontsize="xx-small",
    cmap={
        "fp": "Reds",
        "fn": "Reds",
        "may_neg": "Oranges",
        "may_pos": "Purples",
        "pred": "Greens",
    },
):
    """
    Saves a 1 x 5 figure where each panel shows `img` with a different overlay:
    false positives, false negatives, "may" areas, predicted mask and labels.

    Args:
        output_filename (Path or str): Where the figure is saved
        img (array-like): Background image drawn in every panel
        label (array-like): Label map overlaid in the last panel
        pred (array-like): Predicted mask overlaid in the 4th panel
        metrics_dict (dict): Must hold the "fpr", "fnr", "mnr" and "mpr" values
            printed in the panel titles
        maps_dict (dict): Must hold the "fp", "fn", "may_neg" and "may_pos" maps
        edge_coherence (float, optional): Added to the 3rd panel's title unless
            it is -1. Defaults to -1.
        pred_edge (optional): Currently unused. Defaults to None.
        label_edge (optional): Currently unused. Defaults to None.
        dpi (int, optional): Figure resolution. Defaults to 300.
        alpha (float, optional): Opacity of the overlays. Defaults to 0.5.
        vmin (float, optional): Lower color-scale bound of the overlays
            (not applied to the labels panel). Defaults to 0.0.
        vmax (float, optional): Upper color-scale bound of the overlays
            (not applied to the labels panel). Defaults to 1.0.
        fontsize (str, optional): Font size of the titles. Defaults to "xx-small".
        cmap (dict, optional): Colormap name for each of the "fp", "fn",
            "may_neg", "may_pos" and "pred" overlays. The default dict is only
            read, never modified.
    """
    f, axes = plt.subplots(1, 5, dpi=dpi)

    # Options shared by every overlay except the labels one
    overlay = {"vmin": vmin, "vmax": vmax, "alpha": alpha}

    # FPR (predicted mask on cannot flood)
    axes[0].imshow(img)
    axes[0].imshow(maps_dict["fp"], cmap=cmap["fp"], **overlay)
    axes[0].axis("off")
    axes[0].set_title("FPR: {:.4f}".format(metrics_dict["fpr"]), fontsize=fontsize)

    # FNR (missed mask on must flood)
    axes[1].imshow(img)
    axes[1].imshow(maps_dict["fn"], cmap=cmap["fn"], **overlay)
    axes[1].axis("off")
    axes[1].set_title("FNR: {:.4f}".format(metrics_dict["fnr"]), fontsize=fontsize)

    # May flood
    axes[2].imshow(img)
    # Both titles read the rates from metrics_dict (the fallback title used to
    # reference undefined names and crashed when edge_coherence was -1)
    title = "MNR: {:.2f} | MPR: {:.2f}".format(
        metrics_dict["mnr"], metrics_dict["mpr"]
    )
    if edge_coherence != -1:
        title += "\nEdge coh.: {:.4f}".format(edge_coherence)
    axes[2].imshow(maps_dict["may_neg"], cmap=cmap["may_neg"], **overlay)
    axes[2].imshow(maps_dict["may_pos"], cmap=cmap["may_pos"], **overlay)
    axes[2].set_title(title, fontsize=fontsize)
    axes[2].axis("off")

    # Prediction
    axes[3].imshow(img)
    axes[3].imshow(pred, cmap=cmap["pred"], **overlay)
    axes[3].set_title("Predicted mask", fontsize=fontsize)
    axes[3].axis("off")

    # Labels
    axes[4].imshow(img)
    axes[4].imshow(label, alpha=alpha)
    axes[4].set_title("Labels", fontsize=fontsize)
    axes[4].axis("off")

    f.savefig(
        output_filename,
        dpi=f.dpi,
        bbox_inches="tight",
        facecolor="white",
        transparent=False,
    )
    # Free the figure: this function is called once per image
    plt.close(f)
def load_ground(ground_output_path, ref_image_path):
    """
    Loads the pre-computed mask matching a reference image, as a binary tensor

    The mask is looked up in `ground_output_path`/eval-metrics/pred as a .jpg or
    .png file with the same stem as the reference image. It is pre-processed
    with crop_and_resize, reduced to its first channel if it has 3 dimensions,
    binarized (> 0) and moved to the GPU.

    Args:
        ground_output_path (Path or str): Directory holding "eval-metrics/pred"
        ref_image_path (Path or str): Path to the reference image

    Raises:
        ValueError: If no file, or more than one file, matches the image

    Returns:
        torch.Tensor: float32 CUDA tensor with 2 leading singleton dimensions
    """
    ground_dir = Path(ground_output_path)
    ref_path = Path(ref_image_path)
    pred_dir = ground_dir / "eval-metrics" / "pred"

    # .jpg matches first, then .png
    candidates = [
        match
        for extension in ("jpg", "png")
        for match in pred_dir.glob(f"{ref_path.stem}.{extension}")
    ]

    if not candidates:
        raise ValueError(
            f"Could not find a ground match in {ground_dir} for image {ref_path}"
        )
    if len(candidates) > 1:
        raise ValueError(
            f"Found more than 1 ground match in {ground_dir} for image {ref_path}:"
            + f" {[str(candidate) for candidate in candidates]}"
        )

    _, ground = crop_and_resize(ref_path, candidates[0])
    if ground.ndim == 3:
        # keep the first channel only
        ground = ground[:, :, 0]
    binary = (ground > 0).astype(np.float32)
    tensor = torch.from_numpy(binary)
    return tensor.unsqueeze(0).unsqueeze(0).cuda()
def get_inferences(
    image_arrays, model_path, image_paths, paint=False, bin_value=0.5, verbose=0
):
    """
    Obtains the mask predictions (and optionally the painted images) of a model
    for a set of images. Runs on "cuda:0" and globally disables gradient
    computation.

    If the last component of `model_path` contains "ground" or "instagan", the
    masks are not predicted: they are loaded with `load_ground` from that path,
    and a hard-coded checkpoint is loaded instead to provide the painter.

    Parameters
    ----------
    image_arrays : array-like
        A list of images, each converted with torchvision's ToTensor
        (the caller in this file passes H x W x C arrays)
    model_path : str
        The path to a pre-trained model
    image_paths: list(Path)
        A list of paths for images, in the same order as image_arrays
    paint : bool
        Whether to also paint each image using its mask binarized at bin_value
    bin_value : float
        Threshold applied to the mask before painting
    verbose : int
        Print the progress when > 0

    Returns
    -------
    masks : list
        A list of predicted masks, squeezed and moved to the CPU
    painted : list
        A list of painted images, squeezed and moved to the CPU
        (empty when paint is False)
    """
    device = torch.device("cuda:0")
    torch.set_grad_enabled(False)
    to_tensor = ToTensor()

    model_name = Path(model_path).name
    uses_stored_masks = "ground" in model_name or "instagan" in model_name
    if uses_stored_masks:
        # we just care about the painter here
        ground_path = model_path
        model_path = (
            "/miniscratch/_groups/ccai/experiments/runs/ablation-v1/out--38858350"
        )

    # Batch of 1, float32, on device, then shifted and scaled: (x - 0.5) * 2
    xs = []
    for array in image_arrays:
        x = to_tensor(array).unsqueeze(0).to(torch.float32).to(device)
        xs.append((x - 0.5) * 2)

    trainer = Trainer.resume_from_path(
        model_path, inference=True, new_exp=None, device=device
    )

    masks, painted = [], []
    n_images = len(xs)
    for idx, x in enumerate(xs):
        if verbose > 0:
            print(idx, "/", n_images, end="\r")

        if uses_stored_masks:
            m = load_ground(ground_path, image_paths[idx])
        else:
            m = trainer.G.mask(x=x)
        masks.append(m.squeeze().cpu())

        if paint:
            flooded = trainer.G.paint(m > bin_value, x)
            painted.append(flooded.squeeze().cpu())

    return masks, painted
if __name__ == "__main__":
    # Script entry point:
    #   1. load and pre-process the test images and their labels
    #   2. for each model: predict masks, compute per-image metrics, log
    #      everything to a new Comet experiment (and optionally write to disk)
    #   3. if several models were evaluated and metrics are on disk: build
    #      comparison plots across models and log them to Comet

    # -----------------------------
    # -----  Parse arguments  -----
    # -----------------------------
    args = parsed_args()
    print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))

    # Determine output dir: SLURM's temporary directory when the variable is
    # set, otherwise ask the user
    try:
        tmp_dir = Path(os.environ["SLURM_TMPDIR"])
    except KeyError as e:
        print(e)
        tmp_dir = Path(input("Enter tmp output directory: ")).resolve()

    plot_dir = tmp_dir / "plots"
    plot_dir.mkdir(parents=True, exist_ok=True)

    # Build paths to data. Both lists are sorted by file name (ignoring the
    # "_labeled" suffix of labels) so that images and labels are paired by
    # position.
    imgs_paths = sorted(
        find_images(args.images_dir, recursive=False), key=lambda x: x.name
    )
    labels_paths = sorted(
        find_images(args.labels_dir, recursive=False),
        key=lambda x: x.name.replace("_labeled.", "."),
    )
    if args.max_files > 0:
        imgs_paths = imgs_paths[: args.max_files]
        labels_paths = labels_paths[: args.max_files]

    print(f"Loading {len(imgs_paths)} images and labels...")

    # Pre-process images: resize + crop
    # TODO: ? make cropping more flexible, not only central
    if not args.prepare_torch:
        ims_labs = [crop_and_resize(i, l) for i, l in zip(imgs_paths, labels_paths)]
        imgs = [d[0] for d in ims_labs]
        labels = [d[1] for d in ims_labs]
    else:
        prepare = PrepareTest()
        imgs = prepare(imgs_paths, normalize=False, rescale=False)
        labels = prepare(labels_paths, normalize=False, rescale=False)
        # back to H x W x C uint8 arrays
        imgs = [i.squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8) for i in imgs]
        labels = [
            lab.squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8) for lab in labels
        ]
    # Drop the alpha channel of 4-channel images
    imgs = [rgba2rgb(img) if img.shape[-1] == 4 else img for img in imgs]
    print(" Done.")

    # Encode labels
    print("Encode labels...", end="", flush=True)
    # HW label
    labels = [np.squeeze(encode_mask_label(label, "flood")) for label in labels]
    print("Done.")

    # Models to evaluate: the "models" list of the yaml file, or --model
    if args.yaml:
        y_path = Path(args.yaml)
        assert y_path.exists()
        assert y_path.suffix in {".yaml", ".yml"}
        with y_path.open("r") as f:
            data = yaml.safe_load(f)
        assert "models" in data

        evaluations = list(data["models"])
    else:
        if args.model is None:
            # Fail with an explicit message rather than on Path(None) below
            raise ValueError("No model to evaluate: provide --model or --yaml")
        evaluations = [args.model]

    # Metrics copied as-is from masker_classification_metrics' output
    metric_columns = [
        "tpr",
        "tpt",
        "tnr",
        "tnt",
        "fpr",
        "fpt",
        "fnr",
        "fnt",
        "mnr",
        "mpr",
        "accuracy",
        "error",
        "precision",
        "f05",
        "accuracy_must_may",
    ]

    for e, eval_path in enumerate(evaluations):
        print("\n>>>>> Evaluation", e, ":", eval_path)
        print("=" * 50)
        print("=" * 50)

        model_metrics_path = Path(eval_path) / "eval-metrics"
        model_metrics_path.mkdir(exist_ok=True)
        if args.load_metrics:
            f_csv = model_metrics_path / "eval_masker.csv"
            pred_out = model_metrics_path / "pred"
            if f_csv.exists() and pred_out.exists():
                print("Skipping model because pre-computed metrics exist")
                continue

        # Initialize New Comet Experiment
        exp = Experiment(
            project_name="climategan-masker-metrics", display_summary_level=0
        )

        # Obtain mask predictions
        # TODO: remove (debug)
        print("Obtain mask predictions", end="", flush=True)
        preds, painted = get_inferences(
            imgs,
            eval_path,
            imgs_paths,
            paint=not args.no_paint,
            bin_value=args.bin_value,
            verbose=1,
        )
        preds = [pred.numpy() for pred in preds]
        print(" Done.")

        # Binarize the masks, unless the threshold is disabled (<= 0)
        if args.bin_value > 0:
            preds = [pred > args.bin_value for pred in preds]

        # Compute metrics: one row per image
        df = pd.DataFrame(columns=metric_columns + ["edge_coherence", "filename"])

        print("Compute metrics and plot images")
        for idx, (img, label, pred) in enumerate(zip(imgs, labels, preds)):
            print(idx, "/", len(imgs), end="\r")

            # Basic classification metrics
            metrics_dict, maps_dict = masker_classification_metrics(
                pred, label, labels_dict={"cannot": 0, "must": 1, "may": 2}
            )

            # Edges coherence
            edge_coherence, pred_edge, label_edge = edges_coherence_std_min(pred, label)

            series_dict = {k: metrics_dict[k] for k in metric_columns}
            series_dict["edge_coherence"] = edge_coherence
            series_dict["filename"] = str(imgs_paths[idx].name)
            df.loc[idx] = pd.Series(series_dict)

            for k, v in series_dict.items():
                if k == "filename":
                    continue
                exp.log_metric(f"img_{k}", v, step=idx)

            # Confusion matrix
            # NOTE(review): the last two arguments are passed as (mnr, mpr)
            # here but as (mpr, mnr) for the summary statistics below; one of
            # the two calls is presumably swapped -- confirm against
            # get_confusion_matrix's signature.
            confmat, _ = get_confusion_matrix(
                metrics_dict["tpr"],
                metrics_dict["tnr"],
                metrics_dict["fpr"],
                metrics_dict["fnr"],
                metrics_dict["mnr"],
                metrics_dict["mpr"],
            )
            confmat = np.around(confmat, decimals=3)
            exp.log_confusion_matrix(
                file_name=imgs_paths[idx].name + ".json",
                title=imgs_paths[idx].name,
                matrix=confmat,
                labels=["Cannot", "Must", "May"],
                row_label="Predicted",
                column_label="Ground truth",
            )

            if args.plot:
                # Plot prediction images
                fig_filename = plot_dir / imgs_paths[idx].name
                plot_images(
                    fig_filename,
                    img,
                    label,
                    pred,
                    metrics_dict,
                    maps_dict,
                    edge_coherence,
                    pred_edge,
                    label_edge,
                )
                exp.log_image(fig_filename)

            if not args.no_paint:
                # Side by side: input, input with the mask blacked out, painted
                masked = img * (1 - pred[..., None])
                flooded = img_as_ubyte(
                    (painted[idx].permute(1, 2, 0).cpu().numpy() + 1) / 2
                )
                combined = np.concatenate([img, masked, flooded], 1)
                exp.log_image(combined, imgs_paths[idx].name)

            if args.write_metrics:
                pred_out = model_metrics_path / "pred"
                pred_out.mkdir(exist_ok=True)
                imsave(
                    pred_out / f"{imgs_paths[idx].stem}_pred.png",
                    pred.astype(np.uint8),
                )
                for k, v in maps_dict.items():
                    metric_out = model_metrics_path / k
                    metric_out.mkdir(exist_ok=True)
                    imsave(
                        metric_out / f"{imgs_paths[idx].stem}_{k}.png",
                        v.astype(np.uint8),
                    )

        # --------------------------------
        # -----  END OF IMAGES LOOP  -----
        # --------------------------------

        if args.write_metrics:
            print(f"Writing metrics in {str(model_metrics_path)}")
            f_csv = model_metrics_path / "eval_masker.csv"
            df.to_csv(f_csv, index_label="idx")

        print(" Done.")

        # Summary statistics
        # Average the metric columns only: "filename" holds strings and the
        # frame is filled row by row, so recent pandas versions raise when
        # asked for the mean of the whole frame. Values that are not numbers
        # become NaN and are skipped by mean().
        numeric_df = df.drop(columns=["filename"]).apply(
            pd.to_numeric, errors="coerce"
        )
        means = numeric_df.mean(axis=0)
        confmat_mean, confmat_std = get_confusion_matrix(
            df.tpr, df.tnr, df.fpr, df.fnr, df.mpr, df.mnr
        )
        confmat_mean = np.around(confmat_mean, decimals=3)
        confmat_std = np.around(confmat_std, decimals=3)

        # Log to comet
        exp.log_confusion_matrix(
            file_name="confusion_matrix_mean.json",
            title="confusion_matrix_mean.json",
            matrix=confmat_mean,
            labels=["Cannot", "Must", "May"],
            row_label="Predicted",
            column_label="Ground truth",
        )
        exp.log_confusion_matrix(
            file_name="confusion_matrix_std.json",
            title="confusion_matrix_std.json",
            matrix=confmat_std,
            labels=["Cannot", "Must", "May"],
            row_label="Predicted",
            column_label="Ground truth",
        )
        exp.log_metrics(dict(means))
        exp.log_table("metrics.csv", df)
        exp.log_html(df.to_html(col_space="80px"))
        exp.log_parameters(vars(args))
        exp.log_parameter("eval_path", str(eval_path))
        exp.add_tag("eval_masker")
        if args.tags:
            exp.add_tags(args.tags)
        exp.log_parameter("model_id", Path(eval_path).name)

        # Close comet
        exp.end()

    # --------------------------------
    # -----  END OF MODElS LOOP  -----
    # --------------------------------

    # Compare models
    if (args.load_metrics or args.write_metrics) and len(evaluations) > 1:
        print(
            "Plots for comparing the input models will be created and logged to comet"
        )

        # Initialize New Comet Experiment
        exp = Experiment(
            project_name="climategan-masker-metrics", display_summary_level=0
        )
        if args.tags:
            exp.add_tags(args.tags)

        # Build DataFrame with all models
        print("Building pandas DataFrame...")
        models_df = {}
        for m, model_path in enumerate(evaluations):
            model_path = Path(model_path)
            with open(model_path / "opts.yaml", "r") as f:
                opt = yaml.safe_load(f)
            # Describe the model by its comet tags, minus the bookkeeping ones
            model_feats = ", ".join(
                [
                    t
                    for t in sorted(opt["comet"]["tags"])
                    if "branch" not in t and "ablation" not in t and "trash" not in t
                ]
            )
            model_id = f"{model_path.parent.name[-2:]}/{model_path.name}"
            df_m = pd.read_csv(
                model_path / "eval-metrics" / "eval_masker.csv", index_col=False
            )
            df_m["model"] = [model_id] * len(df_m)
            df_m["model_idx"] = [m] * len(df_m)
            df_m["model_feats"] = [model_feats] * len(df_m)
            models_df[model_id] = df_m
        df = pd.concat(list(models_df.values()), ignore_index=True)
        df["model_img_idx"] = df.model.astype(str) + "-" + df.idx.astype(str)
        df.rename(columns={"idx": "img_idx"}, inplace=True)
        dict_models_labels = {
            k: f"{v['model_idx'][0]}: {v['model_feats'][0]}"
            for k, v in models_df.items()
        }
        print("Done")

        if args.output_csv:
            print(f"Writing DataFrame to {args.output_csv}")
            df.to_csv(args.output_csv, index_label="model_img_idx")

        # Determine images with low metrics in any model
        print("Constructing filter based on metrics thresholds...")
        thresholds = dict_metrics["threshold"]
        # TODO: rethink thresholds
        # Rows (one per model and image) breaking at least one threshold;
        # computed once since it does not depend on the image
        below_threshold = (
            (df.tpr <= thresholds["tpr"])
            | (df.fpr >= thresholds["fpr"])
            | (df.edge_coherence >= thresholds["edge_coherence"])
        )
        idx_not_good_in_any = [
            idx
            for idx in df.img_idx.unique()
            if (below_threshold & (df.img_idx == idx)).any()
        ]
        filters = {"all": df.img_idx.unique(), "not_good_in_any": idx_not_good_in_any}
        print("Done")

        # Boxplots of metrics
        print("Plotting boxplots of metrics...")
        model_order = list(df.model.unique())
        for filter_name, img_ids in filters.items():
            print(f"\tDistribution of [{filter_name}] images...")
            df_filtered = df.loc[df.img_idx.isin(img_ids)]
            for metric in dict_metrics["names"]:
                fig_filename = plot_dir / f"boxplot_{metric}_{filter_name}.png"
                # These three metrics get a strip plot; the others only get
                # smaller outlier markers
                if metric in ["mnr", "mpr", "accuracy_must_may"]:
                    plot_options = {"do_stripplot": True}
                else:
                    plot_options = {"fliersize": 1.0}
                boxplot_metric(
                    fig_filename,
                    df_filtered,
                    metric=metric,
                    dict_metrics=dict_metrics["names"],
                    dict_models=dict_models_labels,
                    order=model_order,
                    **plot_options,
                )
                exp.log_image(fig_filename)
        print("Done")

        # Cluster Maps
        print("Plotting clustermaps...")
        for filter_name, img_ids in filters.items():
            print(f"\tDistribution of [{filter_name}] images...")
            df_filtered = df.loc[df.img_idx.isin(img_ids)]
            for metric in dict_metrics["names"]:
                fig_filename = plot_dir / f"clustermap_{metric}_{filter_name}.png"
                # images as rows, models as columns; keyword arguments because
                # recent pandas versions no longer accept them positionally
                df_mf = df_filtered.pivot(
                    index="img_idx", columns="model", values=metric
                )
                clustermap_metric(
                    output_filename=fig_filename,
                    df=df_mf,
                    metric=metric,
                    dict_metrics=dict_metrics["names"],
                    method="average",
                    cluster_metric="euclidean",
                    dict_models=dict_models_labels,
                    row_cluster=False,
                )
                exp.log_image(fig_filename)
        print("Done")

        # Close comet
        exp.end()