|
""" |
|
This script plots images from the Masker test set overlaid with their labels.
|
""" |
|
print("Imports...", end="") |
|
from argparse import ArgumentParser |
|
import os |
|
import yaml |
|
import numpy as np |
|
import pandas as pd |
|
import seaborn as sns |
|
from pathlib import Path |
|
import matplotlib.pyplot as plt |
|
import matplotlib.patches as mpatches |
|
|
|
import sys |
|
|
|
sys.path.append("../") |
|
|
|
from eval_masker import crop_and_resize |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Label overlay colours, picked by index from seaborn's "colorblind" palette.
# color_pred is not used by the plotting code in this file.
colorblind_palette = sns.color_palette("colorblind")
color_cannot, color_must, color_may, color_pred = (
    colorblind_palette[i] for i in (1, 2, 7, 4)
)

# Colours picked from a 5-step "icefire" palette; by their names presumably
# meant for true/false positive/negative overlays. Not used by the plotting
# code in this file.
icefire = sns.color_palette("icefire", as_cmap=False, n_colors=5)
color_tp, color_tn, color_fp, color_fn = (icefire[i] for i in (0, 1, 4, 3))
|
|
|
|
|
def parsed_args(argv=None):
    """
    Parse and return the command-line arguments.

    ``--masker_test_set_dir`` and ``--images`` are required: the script builds
    every image/label path from them and cannot produce a figure without
    them, so argparse reports their absence up front instead of the script
    failing later (after it has already written output).

    Args:
        argv (list of str, optional): arguments to parse instead of
            ``sys.argv[1:]``. Defaults to None, in which case argparse reads
            ``sys.argv`` exactly as before.

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--input_csv",
        default="ablations_metrics_20210311.csv",
        type=str,
        help="CSV containing the results of the ablation study",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="Output directory",
    )
    parser.add_argument(
        "--masker_test_set_dir",
        required=True,
        type=str,
        help="Directory containing the test images",
    )
    parser.add_argument(
        "--images",
        nargs="+",
        required=True,
        help="List of image file names to plot",
        type=str,
    )
    parser.add_argument(
        "--dpi",
        default=200,
        type=int,
        help="DPI for the output images",
    )
    parser.add_argument(
        "--alpha",
        default=0.5,
        type=float,
        help="Transparency of labels shade",
    )

    return parser.parse_args(argv)
|
|
|
|
|
def map_color(arr, input_color, output_color, rtol=1e-09, atol=1e-08):
    """
    Maps one color to another.

    Returns a copy of ``arr`` in which every pixel whose channels are all
    close (per ``np.isclose``) to ``input_color`` is set to ``output_color``.
    Other pixels are left unchanged, and ``arr`` itself is not modified.

    Args:
        arr (np.ndarray): image of shape (..., C), colour channels last
        input_color (sequence of C numbers): colour to replace
        output_color (sequence of C numbers): replacement colour; it is cast
            to ``arr``'s dtype on assignment
        rtol (float): relative tolerance passed to ``np.isclose``
        atol (float): absolute tolerance passed to ``np.isclose``; defaults
            to numpy's own default, so existing callers behave as before

    Returns:
        np.ndarray: the recoloured copy, same shape and dtype as ``arr``
    """
    # Broadcasting compares every pixel against the colour directly, without
    # allocating a full-size tiled copy of ``input_color``. axis=-1 reduces
    # over the channel axis for any number of leading dimensions.
    matches = np.all(
        np.isclose(arr, np.asarray(input_color), rtol=rtol, atol=atol), axis=-1
    )
    output = arr.copy()
    output[matches] = output_color
    return output
|
|
|
|
|
if __name__ == "__main__":
    # -----------------------------
    # -----  Parse arguments  -----
    # -----------------------------
    args = parsed_args()
    print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))

    # Fail fast, before writing anything, if the inputs needed to build the
    # figure are missing (plt.subplots would reject ncols=0 and Path(None)
    # would raise a TypeError further down).
    if args.masker_test_set_dir is None:
        sys.exit("--masker_test_set_dir is required")
    if not args.images:
        sys.exit("--images needs at least one image file name")

    # -----------------------
    # -----  Output dir -----
    # -----------------------
    if args.output_dir is None:
        slurm_tmpdir = os.environ.get("SLURM_TMPDIR")
        if slurm_tmpdir is None:
            sys.exit("--output_dir was not given and $SLURM_TMPDIR is not set")
        output_dir = Path(slurm_tmpdir)
    else:
        output_dir = Path(args.output_dir)
        # exist_ok=True avoids the check-then-create race of exists() + mkdir()
        output_dir.mkdir(parents=True, exist_ok=True)

    # Store the run's arguments next to the figure
    output_yml = output_dir / "labels.yml"
    with open(output_yml, "w") as f:
        yaml.dump(vars(args), f)

    # Test-set layout: <masker_test_set_dir>/imgs and <masker_test_set_dir>/labels
    test_set_dir = Path(args.masker_test_set_dir)
    imgs_orig_path = test_set_dir / "imgs"
    labels_path = test_set_dir / "labels"

    # NOTE: --input_csv is still accepted and recorded in labels.yml, but its
    # contents are not needed to plot the labels, so the CSV is no longer read
    # (reading it made the script crash whenever the default file was absent).

    # -------------------
    # -----  Style  -----
    # -------------------
    sns.reset_orig()
    sns.set(style="whitegrid")
    plt.rcParams.update(
        {
            "font.family": "serif",
            "font.serif": [
                "Computer Modern Roman",
                "Times New Roman",
                "Utopia",
                "New Century Schoolbook",
                "Century Schoolbook L",
                "ITC Bookman",
                "Bookman",
                "Times",
                "Palatino",
                "Charter",
                # previously a missing comma silently fused these two entries
                # into the nonexistent family "serifBitstream Vera Serif"
                "serif",
                "Bitstream Vera Serif",
                "DejaVu Serif",
            ],
        }
    )

    # ------------------
    # -----  Plot  -----
    # ------------------
    n_imgs = len(args.images)
    # squeeze=False keeps `axes` a 2-D array even for a single image; with the
    # default squeeze=True a one-image run returned a bare Axes and indexing
    # it raised a TypeError.
    fig, axes = plt.subplots(
        nrows=1, ncols=n_imgs, dpi=args.dpi, figsize=(n_imgs * 5, 5), squeeze=False
    )

    # Raw label colour -> display colour, applied in this order. The raw
    # colours are presumably RGB in 0-255 as stored in the *_labeled.png files.
    label_color_map = [
        ((255, 0, 0), color_cannot),
        ((0, 0, 255), color_must),
        ((0, 0, 0), color_may),
    ]

    for ax, img_filename in zip(axes[0], args.images):
        # Load the image and its label, cropped/resized together
        img_path = imgs_orig_path / img_filename
        label_path = labels_path / "{}_labeled.png".format(Path(img_filename).stem)
        img, label = crop_and_resize(img_path, label_path)

        # NOTE(review): pixels matching none of the raw colours keep their
        # original float values; if `label` is 0-255 (assumed) imshow clips
        # them to white. Confirm crop_and_resize's output range.
        label_colmap = label.astype(float)
        for raw_color, display_color in label_color_map:
            label_colmap = map_color(label_colmap, raw_color, display_color)

        ax.imshow(img)
        ax.imshow(label_colmap, alpha=args.alpha)
        ax.axis("off")

    # --------------------
    # -----  Legend  -----
    # --------------------
    lw = 1.0
    # (colour, patch label, legend text) for each label class
    legend_specs = [
        (color_must, "must", "Must-be-flooded"),
        (color_may, "may", "May-be-flooded"),
        (color_cannot, "cannot", "Cannot-be-flooded"),
    ]
    handles = [
        mpatches.Patch(facecolor=color, label=short, linewidth=lw, alpha=args.alpha)
        for color, short, _ in legend_specs
    ]
    labels = [text for _, _, text in legend_specs]
    fig.legend(
        handles=handles,
        labels=labels,
        loc="upper center",
        bbox_to_anchor=(0.0, 0.85, 1.0, 0.075),
        # one column per legend entry (not per image) so all classes sit on a
        # single row regardless of how many images are plotted
        ncol=len(handles),
        fontsize="medium",
        frameon=False,
    )

    # ------------------
    # -----  Save  -----
    # ------------------
    output_fig = output_dir / "labels.png"
    fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
    plt.close(fig)
|
|