""" Convert the aggregation results from the homography adaptation to GT labels. """ import sys sys.path.append("../") import os import yaml import argparse import numpy as np import h5py import torch from tqdm import tqdm from config.project_config import Config as cfg from model.line_detection import LineSegmentDetectionModule from model.metrics import super_nms from misc.train_utils import parse_h5_data def convert_raw_exported_predictions( input_data, grid_size=8, detect_thresh=1 / 65, topk=300 ): """Convert the exported junctions and heatmaps predictions to a standard format. Arguments: input_data: the raw data (dict) decoded from the hdf5 dataset outputs: dict containing required entries including: junctions_pred: Nx2 ndarray containing nms junction predictions. heatmap_pred: HxW ndarray containing predicted heatmaps valid_mask: HxW ndarray containing the valid mask """ # Check the input_data is from (1) single prediction, # or (2) homography adaptation. # Homography adaptation raw predictions if ("junc_prob_mean" in input_data.keys()) and ( "heatmap_prob_mean" in input_data.keys() ): # Get the junction predictions and convert if to Nx2 format junc_prob = input_data["junc_prob_mean"] junc_pred_np = junc_prob[None, ...] junc_pred_np_nms = super_nms(junc_pred_np, grid_size, detect_thresh, topk) junctions = np.where(junc_pred_np_nms.squeeze()) junc_points_pred = np.concatenate( [junctions[0][..., None], junctions[1][..., None]], axis=-1 ) # Get the heatmap predictions heatmap_pred = input_data["heatmap_prob_mean"].squeeze() valid_mask = np.ones(heatmap_pred.shape, dtype=np.int32) # Single predictions else: # Get the junction point predictions and convert to Nx2 format junc_points_pred = np.where(input_data["junc_pred_nms"]) junc_points_pred = np.concatenate( [junc_points_pred[0][..., None], junc_points_pred[1][..., None]], axis=-1 ) # Get the heatmap predictions heatmap_pred = input_data["heatmap_pred"] valid_mask = input_data["valid_mask"] return { "junctions_pred": junc_points_pred, "heatmap_pred": heatmap_pred, "valid_mask": valid_mask, } if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("input_dataset", type=str, help="Name of the exported dataset.") parser.add_argument("output_dataset", type=str, help="Name of the output dataset.") parser.add_argument("config", type=str, help="Path to the model config.") args = parser.parse_args() # Define the path to the input exported dataset exported_dataset_path = os.path.join(cfg.export_dataroot, args.input_dataset) if not os.path.exists(exported_dataset_path): raise ValueError("Missing input dataset: " + exported_dataset_path) exported_dataset = h5py.File(exported_dataset_path, "r") # Define the output path for the results output_dataset_path = os.path.join(cfg.export_dataroot, args.output_dataset) device = torch.device("cuda") nms_device = torch.device("cuda") # Read the config file if not os.path.exists(args.config): raise ValueError("Missing config file: " + args.config) with open(args.config, "r") as f: config = yaml.safe_load(f) model_cfg = config["model_cfg"] line_detector_cfg = config["line_detector_cfg"] # Initialize the line detection module line_detector = LineSegmentDetectionModule(**line_detector_cfg) # Iterate through all the dataset keys with h5py.File(output_dataset_path, "w") as output_dataset: for idx, output_key in enumerate( tqdm(list(exported_dataset.keys()), ascii=True) ): # Get the data data = parse_h5_data(exported_dataset[output_key]) # Preprocess the data converted_data = convert_raw_exported_predictions( data, grid_size=model_cfg["grid_size"], detect_thresh=model_cfg["detection_thresh"], ) junctions_pred_raw = converted_data["junctions_pred"] heatmap_pred = converted_data["heatmap_pred"] valid_mask = converted_data["valid_mask"] line_map_pred, junctions_pred, heatmap_pred = line_detector.detect( junctions_pred_raw, heatmap_pred, device=device ) if isinstance(line_map_pred, torch.Tensor): line_map_pred = line_map_pred.cpu().numpy() if isinstance(junctions_pred, torch.Tensor): junctions_pred = junctions_pred.cpu().numpy() if isinstance(heatmap_pred, torch.Tensor): heatmap_pred = heatmap_pred.cpu().numpy() output_data = {"junctions": junctions_pred, "line_map": line_map_pred} # Record it to the h5 dataset f_group = output_dataset.create_group(output_key) # Store data for key, output_data in output_data.items(): f_group.create_dataset(key, data=output_data, compression="gzip")