import os
import csv

from tqdm import tqdm
from multiprocess import Pool, Process, Queue
from functools import partial
import time

import torch
import torch.utils.data

from . import inference, save_utils, polygonize
from . import local_utils
from . import measures

from lydorn_utils import run_utils
from lydorn_utils import python_utils
from lydorn_utils import print_utils
from lydorn_utils import async_utils


class Evaluator:
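    """Evaluates a model on a dataset split, with one Evaluator per GPU process.

    Inference outputs can be saved individually per tile (through an asynchronous
    saver) and/or aggregated over the whole split (stats CSV, COCO-format JSON).
    """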

    def __init__(self, gpu: int, config: dict, shared_dict, barrier, model, run_dirpath):
        self.gpu = gpu
        self.config = config
        assert 0 < self.config["eval_params"]["batch_size_mult"], \
            "batch_size_mult in eval_params should be at least 1."

        self.shared_dict = shared_dict
        self.barrier = barrier
        self.model = model

        self.checkpoints_dirpath = run_utils.setup_run_subdir(run_dirpath,
                                                              config["optim_params"]["checkpoints_dirname"])

        self.eval_dirpath = os.path.join(config["data_root_dir"], "eval_runs", os.path.split(run_dirpath)[-1])
        if self.gpu == 0:
            os.makedirs(self.eval_dirpath, exist_ok=True)
            print_utils.print_info("Saving eval outputs to {}".format(self.eval_dirpath))

    def evaluate(self, split_name: str, ds: torch.utils.data.DataLoader):
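        """Runs inference on every tile of the given split and saves the outputs
        requested in config["eval_params"]."""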
        flag_filepath_format = os.path.join(self.eval_dirpath, split_name, "{}.flag")

        self.load_checkpoint()
        self.model.eval()
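
        # A process pool for polygonization is only needed when tiles are inferred
        # whole (no patch-based inference):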
        pool = None
        if not self.config["eval_params"]["patch_size"]:
            pool = Pool(processes=self.config["num_workers"])
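
        # Polygonization is only needed if at least one polygon-based output is requested: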
        compute_polygonization = self.config["eval_params"]["save_individual_outputs"]["poly_shapefile"] or \
                                 self.config["eval_params"]["save_individual_outputs"]["poly_geojson"] or \
                                 self.config["eval_params"]["save_individual_outputs"]["poly_viz"] or \
                                 self.config["eval_params"]["save_aggregated_outputs"]["poly_coco"]
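
        # Individual outputs are saved by an asynchronous worker so that writing
        # to disk does not block inference: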
        save_individual_outputs = True in self.config["eval_params"]["save_individual_outputs"].values()
        saver_async = None
        if save_individual_outputs:
            save_outputs_partial = partial(save_utils.save_outputs, config=self.config, eval_dirpath=self.eval_dirpath,
                                           split_name=split_name, flag_filepath_format=flag_filepath_format)
            saver_async = async_utils.Async(save_outputs_partial)
            saver_async.start()

        save_aggregated_outputs = True in self.config["eval_params"]["save_aggregated_outputs"].values()

        tile_data_list = []

        if self.gpu == 0:
            tile_iterator = tqdm(ds, desc="Eval {}: ".format(split_name), leave=True)
        else:
            tile_iterator = ds
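
        # Main loop: infer each tile and accumulate results until batch_size_mult
        # tiles are gathered, then post-process them as one batch: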
        for tile_i, tile_data in enumerate(tile_iterator):
            if self.config["eval_params"]["patch_size"] is not None:
                inference.inference_with_patching(self.config, self.model, tile_data)
            else:
                inference.inference_no_patching(self.config, self.model, tile_data)

            tile_data_list.append(tile_data)
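
            # Post-process once enough tiles have accumulated (or at the last tile);
            # list values are chained and tensor values concatenated along the batch dimension: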
            if self.config["eval_params"]["batch_size_mult"] <= len(tile_data_list) \
                    or tile_i == len(tile_iterator) - 1:
                accumulated_tile_data = {}
                for key in tile_data_list[0].keys():
                    if isinstance(tile_data_list[0][key], list):
                        accumulated_tile_data[key] = [item for _tile_data in tile_data_list for item in _tile_data[key]]
                    elif isinstance(tile_data_list[0][key], torch.Tensor):
                        accumulated_tile_data[key] = torch.cat([_tile_data[key] for _tile_data in tile_data_list], dim=0)
                    else:
                        raise TypeError(f"Type {type(tile_data_list[0][key])} is not handled!")
                tile_data_list = []
            else:
                continue
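
            # Polygonize the predicted segmentation, guided by the predicted crossfield if available: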
            if compute_polygonization:
                crossfield = accumulated_tile_data["crossfield"] if "crossfield" in accumulated_tile_data else None
                accumulated_tile_data["polygons"], accumulated_tile_data["polygon_probs"] = polygonize.polygonize(
                    self.config["polygonize_params"], accumulated_tile_data["seg"],
                    crossfield_batch=crossfield,
                    pool=pool)
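
            # Threshold channel 0 of the predicted segmentation to get a binary mask: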
            if self.config["eval_params"]["save_individual_outputs"]["seg_mask"] or \
                    self.config["eval_params"]["save_aggregated_outputs"]["seg_coco"]:
                seg_pred_mask = self.config["eval_params"]["seg_threshold"] < accumulated_tile_data["seg"][:, 0, ...]
                accumulated_tile_data["seg_mask"] = seg_pred_mask

            # Move everything to CPU and split the batch back into per-tile samples:
            accumulated_tile_data = local_utils.batch_to_cpu(accumulated_tile_data)
            sample_list = local_utils.split_batch(accumulated_tile_data)
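
            # Queue each sample for the asynchronous saver (non-blocking):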
            if save_individual_outputs:
                for sample in sample_list:
                    saver_async.add_work(sample)
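
            # Accumulate aggregated outputs (tile names, IoU stats, COCO annotations)
            # in the dict shared between GPU processes: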
            if save_aggregated_outputs:
                self.shared_dict["name_list"].extend(accumulated_tile_data["name"])
                if self.config["eval_params"]["save_aggregated_outputs"]["stats"]:
                    y_pred = accumulated_tile_data["seg"][:, 0, ...].cpu()
                    if "gt_mask" in accumulated_tile_data:
                        y_true = accumulated_tile_data["gt_mask"][:, 0, ...]
                    elif "gt_polygons_image" in accumulated_tile_data:
                        y_true = accumulated_tile_data["gt_polygons_image"][:, 0, ...]
                    else:
                        raise ValueError("Either gt_mask or gt_polygons_image should be in accumulated_tile_data")
                    iou = measures.iou(y_pred.reshape(y_pred.shape[0], -1), y_true.reshape(y_true.shape[0], -1),
                                       threshold=self.config["eval_params"]["seg_threshold"])
                    self.shared_dict["iou_list"].extend(iou.cpu().numpy())
                if self.config["eval_params"]["save_aggregated_outputs"]["seg_coco"]:
                    for sample in sample_list:
                        annotations = save_utils.seg_coco(sample)
                        self.shared_dict["seg_coco_list"].extend(annotations)
                if self.config["eval_params"]["save_aggregated_outputs"]["poly_coco"]:
                    for sample in sample_list:
                        annotations = save_utils.poly_coco(sample["polygons"], sample["polygon_probs"],
                                                           sample["image_id"].item())
                        self.shared_dict["poly_coco_list"].append(annotations)
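
        # All processes synchronize; rank 0 alone then writes the aggregated files to disk: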
        if save_aggregated_outputs:
            self.barrier.wait()
            if self.gpu == 0:
                if self.config["eval_params"]["save_aggregated_outputs"]["stats"]:
                    print("Start saving stats:")
                    t1 = time.time()
                    stats_filepath = os.path.join(self.eval_dirpath, "{}.stats.csv".format(split_name))
                    with open(stats_filepath, "w") as stats_file:
                        fnames = ["name", "iou"]
                        writer = csv.DictWriter(stats_file, fieldnames=fnames)
                        writer.writeheader()
                        for name, iou in sorted(zip(self.shared_dict["name_list"], self.shared_dict["iou_list"]),
                                                key=lambda pair: pair[0]):
                            writer.writerow({
                                "name": name,
                                "iou": iou
                            })
                    print(f"Finished in {time.time() - t1:.2f}s")

                if self.config["eval_params"]["save_aggregated_outputs"]["seg_coco"]:
                    print("Start saving seg_coco:")
                    t1 = time.time()
                    seg_coco_filepath = os.path.join(self.eval_dirpath, "{}.annotation.seg.json".format(split_name))
                    python_utils.save_json(seg_coco_filepath, list(self.shared_dict["seg_coco_list"]))
                    print(f"Finished in {time.time() - t1:.2f}s")

                if self.config["eval_params"]["save_aggregated_outputs"]["poly_coco"]:
                    print("Start saving poly_coco:")
                    poly_coco_base_filepath = os.path.join(self.eval_dirpath, f"{split_name}.annotation.poly")
                    t1 = time.time()
                    save_utils.save_poly_coco(self.shared_dict["poly_coco_list"], poly_coco_base_filepath)
                    print(f"Finished in {time.time() - t1:.2f}s")

        if save_individual_outputs:
            print_utils.print_info(f"GPU {self.gpu} -> INFO: Waiting for individual outputs to finish saving.")
            saver_async.join()
            self.barrier.wait()

    def load_checkpoint(self):
        """Loads the best validation checkpoint found in checkpoints_dirpath,
        falling back to the latest checkpoint if no best-val checkpoint exists."""
        filepaths = python_utils.get_filepaths(self.checkpoints_dirpath, startswith_str="checkpoint.best_val.",
                                               endswith_str=".tar")
        if len(filepaths):
            filepaths = sorted(filepaths)
            filepath = filepaths[-1]
            if self.gpu == 0:
                print_utils.print_info("Loading best val checkpoint: {}".format(filepath))
        else:
            # No best-val checkpoint: fall back to the most recent regular checkpoint.
            filepaths = python_utils.get_filepaths(self.checkpoints_dirpath, endswith_str=".tar",
                                                   startswith_str="checkpoint.")
            if len(filepaths) == 0:
                raise FileNotFoundError(f"No checkpoint could be found in {self.checkpoints_dirpath}.")
            filepaths = sorted(filepaths)
            filepath = filepaths[-1]
            if self.gpu == 0:
                print_utils.print_info("Loading last checkpoint: {}".format(filepath))

        checkpoint = torch.load(filepath, map_location="cuda:{}".format(self.gpu))
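        # The model is wrapped (e.g. in DistributedDataParallel), hence ".module":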
        self.model.module.load_state_dict(checkpoint['model_state_dict'])