""" Export line detections and descriptors given a list of input images. """ import os import argparse import cv2 import numpy as np import torch from tqdm import tqdm from .experiment import load_config from .model.line_matcher import LineMatcher def export_descriptors( images_list, ckpt_path, config, device, extension, output_folder, multiscale=False ): # Extract the image paths with open(images_list, "r") as f: image_files = f.readlines() image_files = [path.strip("\n") for path in image_files] # Initialize the line matcher line_matcher = LineMatcher( config["model_cfg"], ckpt_path, device, config["line_detector_cfg"], config["line_matcher_cfg"], multiscale, ) print("\t Successfully initialized model") # Run the inference on each image and write the output on disk for img_path in tqdm(image_files): img = cv2.imread(img_path, 0) img = torch.tensor(img[None, None] / 255.0, dtype=torch.float, device=device) # Run the line detection and description ref_detection = line_matcher.line_detection(img) ref_line_seg = ref_detection["line_segments"] ref_descriptors = ref_detection["descriptor"][0].cpu().numpy() # Write the output on disk img_name = os.path.splitext(os.path.basename(img_path))[0] output_file = os.path.join(output_folder, img_name + extension) np.savez_compressed( output_file, line_seg=ref_line_seg, descriptors=ref_descriptors ) if __name__ == "__main__": # Parse input arguments parser = argparse.ArgumentParser() parser.add_argument( "--img_list", type=str, required=True, help="List of input images in a text file.", ) parser.add_argument( "--output_folder", type=str, required=True, help="Path to the output folder." ) parser.add_argument( "--config", type=str, default="config/export_line_features.yaml" ) parser.add_argument( "--checkpoint_path", type=str, default="pretrained_models/sold2_wireframe.tar" ) parser.add_argument("--multiscale", action="store_true", default=False) parser.add_argument("--extension", type=str, default=None) args = parser.parse_args() # Get the device if torch.cuda.is_available(): device = torch.device("cuda") else: device = torch.device("cpu") # Get the model config, extension and checkpoint path config = load_config(args.config) ckpt_path = os.path.abspath(args.checkpoint_path) extension = "sold2" if args.extension is None else args.extension extension = "." + extension export_descriptors( args.img_list, ckpt_path, config, device, extension, args.output_folder, args.multiscale, )