"""
Convert the aggregation results from the homography adaptation to GT labels.
"""

import sys
sys.path.append("../")
import os
import yaml
import argparse
import numpy as np
import h5py
import torch
from tqdm import tqdm
from config.project_config import Config as cfg
from model.line_detection import LineSegmentDetectionModule
from model.metrics import super_nms
from misc.train_utils import parse_h5_data


def convert_raw_exported_predictions(
input_data, grid_size=8, detect_thresh=1 / 65, topk=300
):
"""Convert the exported junctions and heatmaps predictions
to a standard format.
Arguments:
input_data: the raw data (dict) decoded from the hdf5 dataset
outputs: dict containing required entries including:
junctions_pred: Nx2 ndarray containing nms junction predictions.
heatmap_pred: HxW ndarray containing predicted heatmaps
valid_mask: HxW ndarray containing the valid mask
"""
    # Check whether input_data comes from (1) a single prediction
    # or (2) homography adaptation.
# Homography adaptation raw predictions
if ("junc_prob_mean" in input_data.keys()) and (
"heatmap_prob_mean" in input_data.keys()
):
        # Get the junction predictions and convert them to Nx2 format
junc_prob = input_data["junc_prob_mean"]
junc_pred_np = junc_prob[None, ...]
junc_pred_np_nms = super_nms(junc_pred_np, grid_size, detect_thresh, topk)
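        # super_nms keeps only local maxima of the averaged junction
        # probability map above detect_thresh (at most topk of them);
        # np.where then turns the surviving binary map into (row, col) indices.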
junctions = np.where(junc_pred_np_nms.squeeze())
junc_points_pred = np.concatenate(
[junctions[0][..., None], junctions[1][..., None]], axis=-1
)
# Get the heatmap predictions
heatmap_pred = input_data["heatmap_prob_mean"].squeeze()
valid_mask = np.ones(heatmap_pred.shape, dtype=np.int32)
# Single predictions
else:
# Get the junction point predictions and convert to Nx2 format
junc_points_pred = np.where(input_data["junc_pred_nms"])
junc_points_pred = np.concatenate(
[junc_points_pred[0][..., None], junc_points_pred[1][..., None]], axis=-1
)
# Get the heatmap predictions
heatmap_pred = input_data["heatmap_pred"]
valid_mask = input_data["valid_mask"]
return {
"junctions_pred": junc_points_pred,
"heatmap_pred": heatmap_pred,
"valid_mask": valid_mask,
}
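
# Minimal usage sketch (hypothetical key and shapes): for an HxW probability
# map exported with homography adaptation, the conversion yields Nx2
# (row, col) junctions and an HxW heatmap:
#
#   data = parse_h5_data(exported_dataset["some_key"])  # "some_key" is hypothetical
#   converted = convert_raw_exported_predictions(data, grid_size=8)
#   assert converted["junctions_pred"].shape[1] == 2
#   assert converted["heatmap_pred"].shape == converted["valid_mask"].shape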
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("input_dataset", type=str, help="Name of the exported dataset.")
parser.add_argument("output_dataset", type=str, help="Name of the output dataset.")
parser.add_argument("config", type=str, help="Path to the model config.")
args = parser.parse_args()
# Define the path to the input exported dataset
exported_dataset_path = os.path.join(cfg.export_dataroot, args.input_dataset)
if not os.path.exists(exported_dataset_path):
raise ValueError("Missing input dataset: " + exported_dataset_path)
exported_dataset = h5py.File(exported_dataset_path, "r")
# Define the output path for the results
output_dataset_path = os.path.join(cfg.export_dataroot, args.output_dataset)
device = torch.device("cuda")
nms_device = torch.device("cuda")
# Read the config file
if not os.path.exists(args.config):
raise ValueError("Missing config file: " + args.config)
with open(args.config, "r") as f:
config = yaml.safe_load(f)
model_cfg = config["model_cfg"]
line_detector_cfg = config["line_detector_cfg"]
# Initialize the line detection module
line_detector = LineSegmentDetectionModule(**line_detector_cfg)
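    # The detector samples points along candidate segments between junction
    # pairs and checks their support in the heatmap; the sampling and inlier
    # criteria are taken from line_detector_cfg in the YAML config.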
# Iterate through all the dataset keys
with h5py.File(output_dataset_path, "w") as output_dataset:
for idx, output_key in enumerate(
tqdm(list(exported_dataset.keys()), ascii=True)
):
# Get the data
data = parse_h5_data(exported_dataset[output_key])
# Preprocess the data
converted_data = convert_raw_exported_predictions(
data,
grid_size=model_cfg["grid_size"],
detect_thresh=model_cfg["detection_thresh"],
)
junctions_pred_raw = converted_data["junctions_pred"]
heatmap_pred = converted_data["heatmap_pred"]
valid_mask = converted_data["valid_mask"]
line_map_pred, junctions_pred, heatmap_pred = line_detector.detect(
junctions_pred_raw, heatmap_pred, device=device
)
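            # detect() returns a line map (an NxN matrix over the N junctions
            # whose non-zero entries mark detected segments between junction
            # pairs) along with the refined junctions and heatmap.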
if isinstance(line_map_pred, torch.Tensor):
line_map_pred = line_map_pred.cpu().numpy()
if isinstance(junctions_pred, torch.Tensor):
junctions_pred = junctions_pred.cpu().numpy()
if isinstance(heatmap_pred, torch.Tensor):
heatmap_pred = heatmap_pred.cpu().numpy()
output_data = {"junctions": junctions_pred, "line_map": line_map_pred}
# Record it to the h5 dataset
f_group = output_dataset.create_group(output_key)
# Store data
            for key, value in output_data.items():
                f_group.create_dataset(key, data=value, compression="gzip")