"""
    Export line detections and descriptors given a list of input images.
"""
import os
import argparse
import cv2
import numpy as np
import torch
from tqdm import tqdm

from .experiment import load_config
from .model.line_matcher import LineMatcher


def export_descriptors(
    images_list, ckpt_path, config, device, extension, output_folder, multiscale=False
):
    # Extract the image paths, dropping blank lines; a plain strip() also
    # removes Windows "\r\n" line endings and stray whitespace
    with open(images_list, "r") as f:
        image_files = [path.strip() for path in f.readlines()]
    image_files = [path for path in image_files if path]

    # Initialize the line matcher
    line_matcher = LineMatcher(
        config["model_cfg"],
        ckpt_path,
        device,
        config["line_detector_cfg"],
        config["line_matcher_cfg"],
        multiscale,
    )
    print("\t Successfully initialized model")

    # Make sure the output folder exists before writing to it
    os.makedirs(output_folder, exist_ok=True)

    # Run the inference on each image and write the output on disk
    for img_path in tqdm(image_files):
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # Normalize to [0, 1] and add batch and channel dimensions
        img = torch.tensor(img[None, None] / 255.0, dtype=torch.float, device=device)

        # Run the line detection and description
        ref_detection = line_matcher.line_detection(img)
        ref_line_seg = ref_detection["line_segments"]
        ref_descriptors = ref_detection["descriptor"][0].cpu().numpy()

        # Write the output on disk (np.savez_compressed appends ".npz" to
        # the filename when it does not already end with it)
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        output_file = os.path.join(output_folder, img_name + extension)
        np.savez_compressed(
            output_file, line_seg=ref_line_seg, descriptors=ref_descriptors
        )
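
# Minimal sketch of reading an exported file back. np.savez_compressed appends
# ".npz" when the filename does not already end with it, so the file written
# for "img.png" with the default extension is "img.sold2.npz". The keys match
# the keywords passed to np.savez_compressed above:
#
#   data = np.load("img.sold2.npz")
#   line_segments = data["line_seg"]
#   descriptors = data["descriptors"]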


if __name__ == "__main__":
    # Parse input arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--img_list",
        type=str,
        required=True,
        help="List of input images in a text file.",
    )
    parser.add_argument(
        "--output_folder", type=str, required=True, help="Path to the output folder."
    )
    parser.add_argument(
        "--config", type=str, default="config/export_line_features.yaml",
        help="Path to the model configuration file.",
    )
    parser.add_argument(
        "--checkpoint_path", type=str, default="pretrained_models/sold2_wireframe.tar",
        help="Path to the model checkpoint.",
    )
    parser.add_argument("--multiscale", action="store_true", default=False,
                        help="Run the detection at multiple scales.")
    parser.add_argument("--extension", type=str, default=None,
                        help="Extension of the exported files (default: 'sold2').")
    args = parser.parse_args()

    # Get the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Get the model config, extension and checkpoint path
    config = load_config(args.config)
    ckpt_path = os.path.abspath(args.checkpoint_path)
    extension = "sold2" if args.extension is None else args.extension
    extension = "." + extension

    export_descriptors(
        args.img_list,
        ckpt_path,
        config,
        device,
        extension,
        args.output_folder,
        args.multiscale,
    )