gim-online / third_party /SOLD2 /sold2 /postprocess /convert_homography_results.py
Vincentqyw
fix: roma
8b973ee
raw history blame
No virus
5.27 kB
"""
Convert the aggregation results from the homography adaptation to GT labels.
"""
import sys
sys.path.append("../")
import os
import yaml
import argparse
import numpy as np
import h5py
import torch
from tqdm import tqdm
from config.project_config import Config as cfg
from model.line_detection import LineSegmentDetectionModule
from model.metrics import super_nms
from misc.train_utils import parse_h5_data
def convert_raw_exported_predictions(
input_data, grid_size=8, detect_thresh=1 / 65, topk=300
):
"""Convert the exported junctions and heatmaps predictions
to a standard format.
Arguments:
input_data: the raw data (dict) decoded from the hdf5 dataset
outputs: dict containing required entries including:
junctions_pred: Nx2 ndarray containing nms junction predictions.
heatmap_pred: HxW ndarray containing predicted heatmaps
valid_mask: HxW ndarray containing the valid mask
"""
# Check the input_data is from (1) single prediction,
# or (2) homography adaptation.
# Homography adaptation raw predictions
if ("junc_prob_mean" in input_data.keys()) and (
"heatmap_prob_mean" in input_data.keys()
):
# Get the junction predictions and convert if to Nx2 format
junc_prob = input_data["junc_prob_mean"]
junc_pred_np = junc_prob[None, ...]
junc_pred_np_nms = super_nms(junc_pred_np, grid_size, detect_thresh, topk)
junctions = np.where(junc_pred_np_nms.squeeze())
junc_points_pred = np.concatenate(
[junctions[0][..., None], junctions[1][..., None]], axis=-1
)
# Get the heatmap predictions
heatmap_pred = input_data["heatmap_prob_mean"].squeeze()
valid_mask = np.ones(heatmap_pred.shape, dtype=np.int32)
# Single predictions
else:
# Get the junction point predictions and convert to Nx2 format
junc_points_pred = np.where(input_data["junc_pred_nms"])
junc_points_pred = np.concatenate(
[junc_points_pred[0][..., None], junc_points_pred[1][..., None]], axis=-1
)
# Get the heatmap predictions
heatmap_pred = input_data["heatmap_pred"]
valid_mask = input_data["valid_mask"]
return {
"junctions_pred": junc_points_pred,
"heatmap_pred": heatmap_pred,
"valid_mask": valid_mask,
}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("input_dataset", type=str, help="Name of the exported dataset.")
parser.add_argument("output_dataset", type=str, help="Name of the output dataset.")
parser.add_argument("config", type=str, help="Path to the model config.")
args = parser.parse_args()
# Define the path to the input exported dataset
exported_dataset_path = os.path.join(cfg.export_dataroot, args.input_dataset)
if not os.path.exists(exported_dataset_path):
raise ValueError("Missing input dataset: " + exported_dataset_path)
exported_dataset = h5py.File(exported_dataset_path, "r")
# Define the output path for the results
output_dataset_path = os.path.join(cfg.export_dataroot, args.output_dataset)
device = torch.device("cuda")
nms_device = torch.device("cuda")
# Read the config file
if not os.path.exists(args.config):
raise ValueError("Missing config file: " + args.config)
with open(args.config, "r") as f:
config = yaml.safe_load(f)
model_cfg = config["model_cfg"]
line_detector_cfg = config["line_detector_cfg"]
# Initialize the line detection module
line_detector = LineSegmentDetectionModule(**line_detector_cfg)
# Iterate through all the dataset keys
with h5py.File(output_dataset_path, "w") as output_dataset:
for idx, output_key in enumerate(
tqdm(list(exported_dataset.keys()), ascii=True)
):
# Get the data
data = parse_h5_data(exported_dataset[output_key])
# Preprocess the data
converted_data = convert_raw_exported_predictions(
data,
grid_size=model_cfg["grid_size"],
detect_thresh=model_cfg["detection_thresh"],
)
junctions_pred_raw = converted_data["junctions_pred"]
heatmap_pred = converted_data["heatmap_pred"]
valid_mask = converted_data["valid_mask"]
line_map_pred, junctions_pred, heatmap_pred = line_detector.detect(
junctions_pred_raw, heatmap_pred, device=device
)
if isinstance(line_map_pred, torch.Tensor):
line_map_pred = line_map_pred.cpu().numpy()
if isinstance(junctions_pred, torch.Tensor):
junctions_pred = junctions_pred.cpu().numpy()
if isinstance(heatmap_pred, torch.Tensor):
heatmap_pred = heatmap_pred.cpu().numpy()
output_data = {"junctions": junctions_pred, "line_map": line_map_pred}
# Record it to the h5 dataset
f_group = output_dataset.create_group(output_key)
# Store data
for key, output_data in output_data.items():
f_group.create_dataset(key, data=output_data, compression="gzip")