File size: 4,966 Bytes
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#!/usr/bin/env python
# coding: utf-8

import os, glob, cv2
import argparse
from argparse import Namespace
import yaml
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader, SequentialSampler

from src.datasets.custom_dataloader import TestDataLoader
from src.utils.dataset import read_img_gray
from configs.data.base import cfg as data_cfg
import viz


def get_model_config(method_name, dataset_name, root_dir='viz'):
    """Return the model settings of ``method_name`` for ``dataset_name``.

    Parses ``{root_dir}/configs/{method_name}.yml`` with ``yaml.FullLoader``
    and returns the value stored under the ``dataset_name`` top-level key.

    Args:
        method_name: base name (without extension) of the YAML config file.
        dataset_name: top-level key to select inside that file.
        root_dir: directory containing the ``configs`` folder.

    Raises:
        FileNotFoundError: if the config file does not exist.
        KeyError: if ``dataset_name`` is not a key of the parsed YAML mapping.
    """
    config_path = f'{root_dir}/configs/{method_name}.yml'
    with open(config_path, 'r') as stream:
        all_configs = yaml.load(stream, Loader=yaml.FullLoader)
    return all_configs[dataset_name]


class DemoDataset(Dataset):
    """Sequential dataset of grayscale images used by the demo mode.

    The image list comes either from every ``*.*`` entry of ``dataset_dir``
    (sorted lexicographically by path) or from a text file that lists one image
    path per line, relative to ``dataset_dir``.

    Args:
        dataset_dir: directory holding the images.
        img_file: optional path to a text file with one relative image path per
            line. Surrounding whitespace is stripped and blank lines are skipped.
        resize: forwarded to ``read_img_gray``.
        down_factor: forwarded to ``read_img_gray``.
    """

    def __init__(self, dataset_dir, img_file=None, resize=0, down_factor=16):
        self.dataset_dir = dataset_dir
        if img_file is None:
            self.list_img_files = sorted(glob.glob(os.path.join(dataset_dir, "*.*")))
        else:
            with open(img_file) as f:
                # Skip blank lines: os.path.join(dataset_dir, "") would name the
                # directory itself, which inflates len() and cannot be read as an image.
                self.list_img_files = [
                    os.path.join(dataset_dir, line.strip()) for line in f if line.strip()
                ]
        self.resize = resize
        self.down_factor = down_factor

    def __len__(self):
        """Return the number of listed images."""
        return len(self.list_img_files)

    def __getitem__(self, idx):
        """Load image ``idx`` via ``read_img_gray``.

        Returns:
            dict with keys ``"img"`` (first value returned by ``read_img_gray``),
            ``"id"`` (the index ``idx``) and ``"img_path"`` (the file path).
        """
        img_path = self.list_img_files[idx]
        # The second value returned by read_img_gray is not needed by the demo.
        img, _ = read_img_gray(img_path, resize=self.resize, down_factor=self.down_factor)
        return {"img": img, "id": idx, "img_path": img_path}


if __name__ == '__main__':
    # Script entry point. It either evaluates/visualizes matches on a test
    # dataset or, with --run_demo, runs the model's demo over a folder of
    # images and writes topicfm_demo.mp4.
    parser = argparse.ArgumentParser(description='Visualize matches')
    # NOTE(review): --gpu is parsed but never read below; device selection is
    # left to torch / CUDA_VISIBLE_DEVICES. Confirm whether it should be applied.
    parser.add_argument('--gpu', '-gpu', type=str, default='0')
    # Required because the method name selects viz/configs/<method>.yml; without
    # it the lookup would fail later on a confusing "None.yml" path.
    parser.add_argument('--method', type=str, required=True)
    parser.add_argument('--dataset_dir', type=str, default='data/aachen-day-night')
    parser.add_argument('--pair_dir', type=str, default=None)
    parser.add_argument(
        '--dataset_name', type=str, choices=['megadepth', 'scannet', 'aachen_v1.1', 'inloc'], default='megadepth'
    )
    parser.add_argument('--measure_time', action="store_true")
    parser.add_argument('--no_viz', action="store_true")
    parser.add_argument('--compute_eval_metrics', action="store_true")
    parser.add_argument('--run_demo', action="store_true")

    args = parser.parse_args()

    # The YAML entry names the viz class to instantiate and provides its settings.
    model_cfg = get_model_config(args.method, args.dataset_name)
    class_name = model_cfg["class"]
    model = viz.__dict__[class_name](model_cfg)

    if not args.run_demo:
        # Merge the dataset-specific test configuration into the base data config.
        if args.dataset_name == 'megadepth':
            from configs.data.megadepth_test_1500 import cfg

            data_cfg.merge_from_other_cfg(cfg)
        elif args.dataset_name == 'scannet':
            from configs.data.scannet_test_1500 import cfg

            data_cfg.merge_from_other_cfg(cfg)
        elif args.dataset_name == 'aachen_v1.1':
            data_cfg.merge_from_list(["DATASET.TEST_DATA_SOURCE", "aachen_v1.1",
                                      "DATASET.TEST_DATA_ROOT", os.path.join(args.dataset_dir, "images/images_upright"),
                                      "DATASET.TEST_LIST_PATH", args.pair_dir,
                                      "DATASET.TEST_IMGSIZE", model_cfg["imsize"]])
        elif args.dataset_name == 'inloc':
            data_cfg.merge_from_list(["DATASET.TEST_DATA_SOURCE", "inloc",
                                      "DATASET.TEST_DATA_ROOT", args.dataset_dir,
                                      "DATASET.TEST_LIST_PATH", args.pair_dir,
                                      "DATASET.TEST_IMGSIZE", model_cfg["imsize"]])

        # Only MegaDepth and ScanNet are treated as having ground truth.
        has_ground_truth = str(data_cfg.DATASET.TEST_DATA_SOURCE).lower() in ["megadepth", "scannet"]
        dataloader = TestDataLoader(data_cfg)
        # Loop invariants: query CUDA once and read the image root once.
        use_cuda = torch.cuda.is_available()
        img_root_dir = data_cfg.DATASET.TEST_DATA_ROOT
        with torch.no_grad():
            for data_dict in tqdm(dataloader):
                if use_cuda:
                    # Reassigning existing keys does not resize the dict, so
                    # updating it while iterating over items() is safe.
                    for k, v in data_dict.items():
                        if isinstance(v, torch.Tensor):
                            data_dict[k] = v.cuda()
                model.match_and_draw(data_dict, root_dir=img_root_dir, ground_truth=has_ground_truth,
                                     measure_time=args.measure_time, viz_matches=(not args.no_viz))

        if args.measure_time:
            print("Running time for each image is {} milliseconds".format(model.measure_time()))
        if args.compute_eval_metrics:
            if has_ground_truth:
                model.compute_eval_metrics()
            else:
                print("Skipping evaluation metrics: dataset '{}' has no ground truth".format(args.dataset_name))
    else:
        demo_dataset = DemoDataset(args.dataset_dir, img_file=args.pair_dir, resize=640)
        sampler = SequentialSampler(demo_dataset)
        dataloader = DataLoader(demo_dataset, batch_size=1, sampler=sampler)

        # Frame size is (width, height); it presumably must match the frames
        # model.run_demo writes, or OpenCV drops them. TODO confirm.
        writer = cv2.VideoWriter('topicfm_demo.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 15,
                                 (640 * 2 + 5, 480 * 2 + 10))
        try:
            model.run_demo(iter(dataloader), writer)
        finally:
            # Finalize the mp4 deterministically instead of relying on the
            # destructor at interpreter shutdown. release() is a no-op if
            # run_demo already released the writer.
            writer.release()