#!/usr/bin/env python
# coding: utf-8
"""Run a matching method on a benchmark test set or on a demo image sequence.

Without ``--run_demo`` the script builds a ``TestDataLoader`` for the chosen
dataset and calls ``model.match_and_draw`` on every batch.  With
``--run_demo`` it feeds the images of ``--dataset_dir`` to ``model.run_demo``
together with a video writer targeting ``topicfm_demo.mp4``.
"""

import argparse
from argparse import Namespace
import glob
import os

import cv2
import yaml
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader, SequentialSampler

from src.datasets.custom_dataloader import TestDataLoader
from src.utils.dataset import read_img_gray
from configs.data.base import cfg as data_cfg
import viz


def get_model_config(method_name, dataset_name, root_dir="viz"):
    """Load the configuration of one method for one dataset.

    Args:
        method_name: Base name of the YAML file under ``<root_dir>/configs``.
        dataset_name: Top-level key of the YAML document to return.
        root_dir: Directory that contains the ``configs`` folder.

    Returns:
        The value stored under ``dataset_name`` in
        ``<root_dir>/configs/<method_name>.yml``.

    Raises:
        FileNotFoundError: If the config file does not exist.
        KeyError: If the file has no entry for ``dataset_name``.
    """
    config_file = f"{root_dir}/configs/{method_name}.yml"
    with open(config_file, "r") as f:
        # NOTE(review): FullLoader can build some Python objects from tags, so
        # it is only appropriate for trusted, project-owned config files.
        model_conf = yaml.load(f, Loader=yaml.FullLoader)[dataset_name]
    return model_conf


class DemoDataset(Dataset):
    """Dataset that yields the images of a directory one by one.

    The image list is either every ``*.*`` entry of ``dataset_dir`` in sorted
    order, or the entries named (one per line, relative to ``dataset_dir``) in
    ``img_file``.
    """

    def __init__(self, dataset_dir, img_file=None, resize=0, down_factor=16):
        """Collect the image paths.

        Args:
            dataset_dir: Directory that holds the images.
            img_file: Optional text file listing one relative image path per
                line.  Blank lines are ignored.
            resize: Forwarded to ``read_img_gray`` as ``resize``.
            down_factor: Forwarded to ``read_img_gray`` as ``down_factor``.
        """
        self.dataset_dir = dataset_dir
        if img_file is None:
            self.list_img_files = sorted(glob.glob(os.path.join(dataset_dir, "*.*")))
        else:
            with open(img_file) as f:
                # Skip blank lines (e.g. a trailing newline at the end of the
                # list); they would otherwise become an entry equal to
                # dataset_dir itself.
                self.list_img_files = [
                    os.path.join(dataset_dir, line.strip())
                    for line in f
                    if line.strip()
                ]
        self.resize = resize
        self.down_factor = down_factor

    def __len__(self):
        """Return the number of listed images."""
        return len(self.list_img_files)

    def __getitem__(self, idx):
        """Load image ``idx`` and return it with its index and path."""
        # Paths are stored already joined with dataset_dir.
        img_path = self.list_img_files[idx]
        # read_img_gray returns (image, scale); the scale is not needed here.
        img, _ = read_img_gray(
            img_path, resize=self.resize, down_factor=self.down_factor
        )
        return {"img": img, "id": idx, "img_path": img_path}


def _parse_args():
    """Parse and validate the command-line arguments."""
    parser = argparse.ArgumentParser(description="Visualize matches")
    # NOTE(review): --gpu is accepted but not read anywhere in this script.
    parser.add_argument("--gpu", "-gpu", type=str, default="0")
    parser.add_argument("--method", type=str, default=None)
    parser.add_argument("--dataset_dir", type=str, default="data/aachen-day-night")
    parser.add_argument("--pair_dir", type=str, default=None)
    parser.add_argument(
        "--dataset_name",
        type=str,
        choices=["megadepth", "scannet", "aachen_v1.1", "inloc"],
        default="megadepth",
    )
    parser.add_argument("--measure_time", action="store_true")
    parser.add_argument("--no_viz", action="store_true")
    parser.add_argument("--compute_eval_metrics", action="store_true")
    parser.add_argument("--run_demo", action="store_true")
    args = parser.parse_args()

    # Without a method name the config lookup would fail with an opaque
    # FileNotFoundError on ".../None.yml"; report the real problem instead.
    if args.method is None:
        parser.error("--method is required (name of a YAML file under viz/configs)")
    return args


def _configure_test_data(args, model_cfg):
    """Merge the settings of ``args.dataset_name`` into the global ``data_cfg``."""
    if args.dataset_name == "megadepth":
        from configs.data.megadepth_test_1500 import cfg

        data_cfg.merge_from_other_cfg(cfg)
    elif args.dataset_name == "scannet":
        from configs.data.scannet_test_1500 import cfg

        data_cfg.merge_from_other_cfg(cfg)
    elif args.dataset_name == "aachen_v1.1":
        data_cfg.merge_from_list(
            [
                "DATASET.TEST_DATA_SOURCE",
                "aachen_v1.1",
                "DATASET.TEST_DATA_ROOT",
                os.path.join(args.dataset_dir, "images/images_upright"),
                "DATASET.TEST_LIST_PATH",
                args.pair_dir,
                "DATASET.TEST_IMGSIZE",
                model_cfg["imsize"],
            ]
        )
    elif args.dataset_name == "inloc":
        data_cfg.merge_from_list(
            [
                "DATASET.TEST_DATA_SOURCE",
                "inloc",
                "DATASET.TEST_DATA_ROOT",
                args.dataset_dir,
                "DATASET.TEST_LIST_PATH",
                args.pair_dir,
                "DATASET.TEST_IMGSIZE",
                model_cfg["imsize"],
            ]
        )


def _run_benchmark(model, args, model_cfg):
    """Match every batch of the configured test set and report the results."""
    _configure_test_data(args, model_cfg)

    # Only these sources are treated as having ground truth in this script.
    has_ground_truth = str(data_cfg.DATASET.TEST_DATA_SOURCE).lower() in [
        "megadepth",
        "scannet",
    ]

    dataloader = TestDataLoader(data_cfg)
    # Both values are constant for the whole run, so look them up once.
    img_root_dir = data_cfg.DATASET.TEST_DATA_ROOT
    use_cuda = torch.cuda.is_available()

    with torch.no_grad():
        for data_dict in tqdm(dataloader):
            if use_cuda:
                # Re-assigning existing keys is safe while iterating items().
                for k, v in data_dict.items():
                    if isinstance(v, torch.Tensor):
                        data_dict[k] = v.cuda()
            model.match_and_draw(
                data_dict,
                root_dir=img_root_dir,
                ground_truth=has_ground_truth,
                measure_time=args.measure_time,
                viz_matches=(not args.no_viz),
            )

    if args.measure_time:
        print(
            "Running time for each image is {} miliseconds".format(
                model.measure_time()
            )
        )
    if args.compute_eval_metrics and has_ground_truth:
        model.compute_eval_metrics()


def _run_demo(model, args):
    """Feed the demo images to ``model.run_demo`` and record ``topicfm_demo.mp4``."""
    demo_dataset = DemoDataset(args.dataset_dir, img_file=args.pair_dir, resize=640)
    sampler = SequentialSampler(demo_dataset)
    dataloader = DataLoader(demo_dataset, batch_size=1, sampler=sampler)

    # NOTE(review): the frame size is hard-coded; presumably it has to match
    # the frames that run_demo writes -- confirm against run_demo.
    writer = cv2.VideoWriter(
        "topicfm_demo.mp4",
        cv2.VideoWriter_fourcc(*"mp4v"),
        15,
        (640 * 2 + 5, 480 * 2 + 10),
    )
    try:
        # The original call carried a commented-out tail:
        # , output_dir="demo", no_display=True)
        model.run_demo(iter(dataloader), writer)
    finally:
        # Always close the writer so the video file is finalized, even when
        # run_demo raises.  Releasing an already released writer is harmless.
        writer.release()


def main():
    """Entry point: build the model and run either the demo or the benchmark."""
    args = _parse_args()

    model_cfg = get_model_config(args.method, args.dataset_name)
    class_name = model_cfg["class"]
    # The config names the class; it is looked up in the viz package namespace.
    model = viz.__dict__[class_name](model_cfg)
    # all_args = Namespace(**vars(args), **model_cfg)

    if args.run_demo:
        _run_demo(model, args)
    else:
        _run_benchmark(model, args, model_cfg)


if __name__ == "__main__":
    main()