File size: 5,590 Bytes
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
Convert the aggregation results from the homography adaptation to GT labels.
"""
import sys
sys.path.append("../")
import os
import yaml
import argparse
import numpy as np
import h5py
import torch
from tqdm import tqdm

from config.project_config import Config as cfg
from model.line_detection import LineSegmentDetectionModule
from model.metrics import super_nms
from misc.train_utils import parse_h5_data


def convert_raw_exported_predictions(input_data, grid_size=8,
                                     detect_thresh=1/65, topk=300):
    """ Convert the exported junctions and heatmaps predictions
        to a standard format.
    Arguments:
        input_data: the raw data (dict) decoded from the hdf5 dataset
        outputs: dict containing required entries including:
            junctions_pred: Nx2 ndarray containing nms junction predictions.
            heatmap_pred: HxW ndarray containing predicted heatmaps
            valid_mask: HxW ndarray containing the valid mask
    """
    # Check the input_data is from (1) single prediction,
    # or (2) homography adaptation.
    # Homography adaptation raw predictions
    if (("junc_prob_mean" in input_data.keys())
        and ("heatmap_prob_mean" in input_data.keys())):
        # Get the junction predictions and convert if to Nx2 format
        junc_prob = input_data["junc_prob_mean"]
        junc_pred_np = junc_prob[None, ...]
        junc_pred_np_nms = super_nms(junc_pred_np, grid_size,
                                     detect_thresh, topk)
        junctions = np.where(junc_pred_np_nms.squeeze())
        junc_points_pred = np.concatenate([junctions[0][..., None],
                                           junctions[1][..., None]], axis=-1)

        # Get the heatmap predictions
        heatmap_pred = input_data["heatmap_prob_mean"].squeeze()
        valid_mask = np.ones(heatmap_pred.shape, dtype=np.int32)
        
    # Single predictions
    else:
        # Get the junction point predictions and convert to Nx2 format
        junc_points_pred = np.where(input_data["junc_pred_nms"])
        junc_points_pred = np.concatenate(
            [junc_points_pred[0][..., None],
             junc_points_pred[1][..., None]], axis=-1)

        # Get the heatmap predictions
        heatmap_pred = input_data["heatmap_pred"]
        valid_mask = input_data["valid_mask"]

    return {
        "junctions_pred": junc_points_pred,
        "heatmap_pred": heatmap_pred,
        "valid_mask": valid_mask
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("input_dataset", type=str,
                        help="Name of the exported dataset.")
    parser.add_argument("output_dataset", type=str,
                        help="Name of the output dataset.")
    parser.add_argument("config", type=str,
                        help="Path to the model config.")
    args = parser.parse_args()    
    
    # Define the path to the input exported dataset
    exported_dataset_path = os.path.join(cfg.export_dataroot,
                                         args.input_dataset)
    if not os.path.exists(exported_dataset_path):
        raise ValueError("Missing input dataset: " + exported_dataset_path)
    exported_dataset = h5py.File(exported_dataset_path, "r")

    # Define the output path for the results
    output_dataset_path = os.path.join(cfg.export_dataroot,
                                       args.output_dataset)

    device = torch.device("cuda")
    nms_device = torch.device("cuda")
    
    # Read the config file
    if not os.path.exists(args.config):
        raise ValueError("Missing config file: " + args.config)
    with open(args.config, "r") as f:
        config = yaml.safe_load(f)
    model_cfg = config["model_cfg"]
    line_detector_cfg = config["line_detector_cfg"]
    
    # Initialize the line detection module
    line_detector = LineSegmentDetectionModule(**line_detector_cfg)

    # Iterate through all the dataset keys
    with h5py.File(output_dataset_path, "w") as output_dataset:
        for idx, output_key in enumerate(tqdm(list(exported_dataset.keys()),
                                              ascii=True)):
            # Get the data
            data = parse_h5_data(exported_dataset[output_key])

            # Preprocess the data
            converted_data = convert_raw_exported_predictions(
                data, grid_size=model_cfg["grid_size"],
                detect_thresh=model_cfg["detection_thresh"])
            junctions_pred_raw = converted_data["junctions_pred"]
            heatmap_pred = converted_data["heatmap_pred"]
            valid_mask = converted_data["valid_mask"]

            line_map_pred, junctions_pred, heatmap_pred = line_detector.detect(
                junctions_pred_raw, heatmap_pred, device=device)
            if isinstance(line_map_pred, torch.Tensor):
                line_map_pred = line_map_pred.cpu().numpy()
            if isinstance(junctions_pred, torch.Tensor):
                junctions_pred = junctions_pred.cpu().numpy()
            if isinstance(heatmap_pred, torch.Tensor):
                heatmap_pred = heatmap_pred.cpu().numpy()
            
            output_data = {"junctions": junctions_pred,
                           "line_map": line_map_pred}

            # Record it to the h5 dataset
            f_group = output_dataset.create_group(output_key)

            # Store data
            for key, output_data in output_data.items():
                f_group.create_dataset(key, data=output_data,
                                       compression="gzip")