Spaces:
Runtime error
Runtime error
""" | |
This scripts plots images from the Masker test set overlaid with their labels. | |
""" | |
print("Imports...", end="") | |
from argparse import ArgumentParser | |
import os | |
import yaml | |
import numpy as np | |
import pandas as pd | |
import seaborn as sns | |
from pathlib import Path | |
import matplotlib.pyplot as plt | |
import matplotlib.patches as mpatches | |
import sys | |
sys.path.append("../") | |
from eval_masker import crop_and_resize | |
# ----------------------- | |
# ----- Constants ----- | |
# ----------------------- | |
# Colors
# Label and prediction colors are picked by position from seaborn's
# "colorblind" palette
colorblind_palette = sns.color_palette("colorblind")
color_cannot, color_must, color_may, color_pred = (
    colorblind_palette[position] for position in (1, 2, 7, 4)
)
# tp / tn / fp / fn colors are picked by position from a 5-color "icefire"
# palette (presumably true/false positives/negatives -- not used in the
# visible part of this script)
icefire = sns.color_palette("icefire", as_cmap=False, n_colors=5)
color_tp, color_tn, color_fp, color_fn = (
    icefire[position] for position in (0, 1, 4, 3)
)
def parsed_args():
    """
    Declare the command-line options of the script and parse them

    The options are described in a table and registered on the parser in
    the order in which they are listed.

    Returns:
        argparse.Namespace: namespace with the attributes input_csv,
            output_dir, masker_test_set_dir, images, dpi and alpha
    """
    # One (flag, add_argument keyword arguments) pair per option
    options = (
        (
            "--input_csv",
            dict(
                default="ablations_metrics_20210311.csv",
                type=str,
                help="CSV containing the results of the ablation study",
            ),
        ),
        (
            "--output_dir",
            dict(default=None, type=str, help="Output directory"),
        ),
        (
            "--masker_test_set_dir",
            dict(
                default=None,
                type=str,
                help="Directory containing the test images",
            ),
        ),
        (
            "--images",
            dict(
                nargs="+",
                help="List of image file names to plot",
                default=[],
                type=str,
            ),
        ),
        (
            "--dpi",
            dict(default=200, type=int, help="DPI for the output images"),
        ),
        (
            "--alpha",
            dict(default=0.5, type=float, help="Transparency of labels shade"),
        ),
    )

    cli = ArgumentParser()
    for flag, settings in options:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
def map_color(arr, input_color, output_color, rtol=1e-09):
    """
    Return a copy of an image array with one color replaced by another

    A pixel is replaced when all of its values along axis 2 are close
    (np.isclose with the given relative tolerance) to input_color.

    Args:
        arr (np.ndarray): image array; its first two axes index the pixels
            and axis 2 holds the color values
        input_color: color to look for
        output_color: color written in place of every matching pixel
        rtol (float): relative tolerance forwarded to np.isclose

    Returns:
        np.ndarray: recolored copy; arr itself is left untouched
    """
    # Repeat the reference color over the first two axes of the image
    grid_shape = arr.shape[:2]
    reference = np.tile(input_color, grid_shape + (1,))
    # A pixel matches only if every one of its values along axis 2 is close
    matches = np.isclose(arr, reference, rtol=rtol).all(axis=2)
    recolored = arr.copy()
    recolored[matches] = output_color
    return recolored
if __name__ == "__main__": | |
# ----------------------------- | |
# ----- Parse arguments ----- | |
# ----------------------------- | |
args = parsed_args() | |
print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()])) | |
# Determine output dir | |
if args.output_dir is None: | |
output_dir = Path(os.environ["SLURM_TMPDIR"]) | |
else: | |
output_dir = Path(args.output_dir) | |
if not output_dir.exists(): | |
output_dir.mkdir(parents=True, exist_ok=False) | |
# Store args | |
output_yml = output_dir / "labels.yml" | |
with open(output_yml, "w") as f: | |
yaml.dump(vars(args), f) | |
# Data dirs | |
imgs_orig_path = Path(args.masker_test_set_dir) / "imgs" | |
labels_path = Path(args.masker_test_set_dir) / "labels" | |
# Read CSV | |
df = pd.read_csv(args.input_csv, index_col="model_img_idx") | |
# Set up plot | |
sns.reset_orig() | |
sns.set(style="whitegrid") | |
plt.rcParams.update({"font.family": "serif"}) | |
plt.rcParams.update( | |
{ | |
"font.serif": [ | |
"Computer Modern Roman", | |
"Times New Roman", | |
"Utopia", | |
"New Century Schoolbook", | |
"Century Schoolbook L", | |
"ITC Bookman", | |
"Bookman", | |
"Times", | |
"Palatino", | |
"Charter", | |
"serif" "Bitstream Vera Serif", | |
"DejaVu Serif", | |
] | |
} | |
) | |
fig, axes = plt.subplots( | |
nrows=1, ncols=len(args.images), dpi=args.dpi, figsize=(len(args.images) * 5, 5) | |
) | |
for idx, img_filename in enumerate(args.images): | |
# Read images | |
img_path = imgs_orig_path / img_filename | |
label_path = labels_path / "{}_labeled.png".format(Path(img_filename).stem) | |
img, label = crop_and_resize(img_path, label_path) | |
# Map label colors | |
label_colmap = label.astype(float) | |
label_colmap = map_color(label_colmap, (255, 0, 0), color_cannot) | |
label_colmap = map_color(label_colmap, (0, 0, 255), color_must) | |
label_colmap = map_color(label_colmap, (0, 0, 0), color_may) | |
ax = axes[idx] | |
ax.imshow(img) | |
ax.imshow(label_colmap, alpha=args.alpha) | |
ax.axis("off") | |
# Legend | |
handles = [] | |
lw = 1.0 | |
handles.append( | |
mpatches.Patch( | |
facecolor=color_must, label="must", linewidth=lw, alpha=args.alpha | |
) | |
) | |
handles.append( | |
mpatches.Patch(facecolor=color_may, label="may", linewidth=lw, alpha=args.alpha) | |
) | |
handles.append( | |
mpatches.Patch( | |
facecolor=color_cannot, label="cannot", linewidth=lw, alpha=args.alpha | |
) | |
) | |
labels = ["Must-be-flooded", "May-be-flooded", "Cannot-be-flooded"] | |
fig.legend( | |
handles=handles, | |
labels=labels, | |
loc="upper center", | |
bbox_to_anchor=(0.0, 0.85, 1.0, 0.075), | |
ncol=len(args.images), | |
fontsize="medium", | |
frameon=False, | |
) | |
# Save figure | |
output_fig = output_dir / "labels.png" | |
fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight") | |