File size: 5,553 Bytes
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c74a070
 
 
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
c74a070
 
 
 
a80d6bb
 
 
 
 
 
 
c74a070
 
 
 
 
 
a80d6bb
 
 
c74a070
 
 
 
 
 
a80d6bb
c74a070
 
 
 
a80d6bb
c74a070
 
 
 
a80d6bb
 
 
 
 
 
 
 
c74a070
a80d6bb
 
 
c74a070
a80d6bb
 
 
c74a070
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a80d6bb
 
 
 
 
 
 
c74a070
 
 
 
 
 
 
a80d6bb
 
c74a070
 
 
 
 
a80d6bb
 
 
 
 
 
 
c74a070
 
 
 
 
 
a80d6bb
c74a070
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#!/usr/bin/env python
# coding: utf-8

import os, glob, cv2
import argparse
from argparse import Namespace
import yaml
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader, SequentialSampler

from src.datasets.custom_dataloader import TestDataLoader
from src.utils.dataset import read_img_gray
from configs.data.base import cfg as data_cfg
import viz


def get_model_config(method_name, dataset_name, root_dir="viz"):
    """Load one dataset's section from a method's YAML configuration file.

    The file read is ``<root_dir>/configs/<method_name>.yml``; the value stored
    under the ``dataset_name`` key of the parsed document is returned.

    Raises:
        FileNotFoundError: if the configuration file does not exist.
        KeyError: if the parsed document (assumed to be a mapping) has no
            ``dataset_name`` entry.
    """
    path = "{}/configs/{}.yml".format(root_dir, method_name)
    with open(path, "r") as stream:
        all_configs = yaml.load(stream, Loader=yaml.FullLoader)
    return all_configs[dataset_name]


class DemoDataset(Dataset):
    """Dataset of single grayscale images used by the ``--run_demo`` mode.

    Image paths come either from globbing ``dataset_dir`` for ``*.*`` (sorted
    by path) or, when ``img_file`` is given, from that text file with one path
    per line; listed paths are joined onto ``dataset_dir``.

    Args:
        dataset_dir: directory containing the images.
        img_file: optional text file listing image paths relative to
            ``dataset_dir``, one per line. Surrounding whitespace is stripped
            and blank lines are ignored.
        resize: forwarded to ``read_img_gray``.
        down_factor: forwarded to ``read_img_gray``.
    """

    def __init__(self, dataset_dir, img_file=None, resize=0, down_factor=16):
        self.dataset_dir = dataset_dir
        if img_file is None:
            self.list_img_files = sorted(
                glob.glob(os.path.join(dataset_dir, "*.*"))
            )
        else:
            with open(img_file) as f:
                names = (line.strip() for line in f)
                # Skip blank lines (e.g. a trailing newline): joined onto
                # dataset_dir they would yield the directory path itself,
                # which cannot be loaded as an image.
                self.list_img_files = [
                    os.path.join(dataset_dir, name) for name in names if name
                ]
        self.resize = resize
        self.down_factor = down_factor

    def __len__(self):
        return len(self.list_img_files)

    def __getitem__(self, idx):
        """Return a dict with the loaded ``img``, its index ``id`` and ``img_path``."""
        img_path = self.list_img_files[idx]
        # read_img_gray also returns a scale factor, which the demo does not use.
        img, _scale = read_img_gray(
            img_path, resize=self.resize, down_factor=self.down_factor
        )
        return {"img": img, "id": idx, "img_path": img_path}


def _merge_test_data_cfg(args, model_cfg):
    """Merge the evaluation-set settings for ``args.dataset_name`` into ``data_cfg``.

    ``data_cfg`` is updated in place. MegaDepth and ScanNet take their settings
    from dedicated config modules; Aachen v1.1 and InLoc are configured from
    the command-line paths plus the model's ``imsize`` entry.
    """
    dataset_name = args.dataset_name
    if dataset_name == "megadepth":
        from configs.data.megadepth_test_1500 import cfg

        data_cfg.merge_from_other_cfg(cfg)
    elif dataset_name == "scannet":
        from configs.data.scannet_test_1500 import cfg

        data_cfg.merge_from_other_cfg(cfg)
    elif dataset_name in ("aachen_v1.1", "inloc"):
        # Aachen images are read from images/images_upright; InLoc uses the
        # dataset directory directly.
        if dataset_name == "aachen_v1.1":
            data_root = os.path.join(args.dataset_dir, "images/images_upright")
        else:
            data_root = args.dataset_dir
        data_cfg.merge_from_list(
            [
                "DATASET.TEST_DATA_SOURCE",
                dataset_name,
                "DATASET.TEST_DATA_ROOT",
                data_root,
                "DATASET.TEST_LIST_PATH",
                args.pair_dir,
                "DATASET.TEST_IMGSIZE",
                model_cfg["imsize"],
            ]
        )


def _run_matching(args, model, model_cfg):
    """Match every test pair with ``model``, then optionally report timing/metrics."""
    _merge_test_data_cfg(args, model_cfg)
    # Only these two sources are treated as having ground truth by this script.
    has_ground_truth = str(data_cfg.DATASET.TEST_DATA_SOURCE).lower() in (
        "megadepth",
        "scannet",
    )
    dataloader = TestDataLoader(data_cfg)
    use_cuda = torch.cuda.is_available()
    img_root_dir = data_cfg.DATASET.TEST_DATA_ROOT
    with torch.no_grad():
        for data_dict in tqdm(dataloader):
            if use_cuda:
                # Reassigning existing keys does not resize the dict, so
                # updating it while iterating is safe.
                for key, value in data_dict.items():
                    if isinstance(value, torch.Tensor):
                        data_dict[key] = value.cuda()
            model.match_and_draw(
                data_dict,
                root_dir=img_root_dir,
                ground_truth=has_ground_truth,
                measure_time=args.measure_time,
                viz_matches=(not args.no_viz),
            )

    if args.measure_time:
        print(
            "Running time for each image is {} milliseconds".format(
                model.measure_time()
            )
        )
    if args.compute_eval_metrics:
        if has_ground_truth:
            model.compute_eval_metrics()
        else:
            print(
                "Skipping --compute_eval_metrics: no ground truth for {}".format(
                    args.dataset_name
                )
            )


def _run_demo(args, model):
    """Run ``model.run_demo`` over the images of ``args.dataset_dir`` into an MP4."""
    demo_dataset = DemoDataset(args.dataset_dir, img_file=args.pair_dir, resize=640)
    dataloader = DataLoader(
        demo_dataset, batch_size=1, sampler=SequentialSampler(demo_dataset)
    )
    # NOTE(review): the frame size is hard-coded for 640x480 inputs and
    # presumably must match the frames model.run_demo writes — confirm.
    writer = cv2.VideoWriter(
        "topicfm_demo.mp4",
        cv2.VideoWriter_fourcc(*"mp4v"),
        15,
        (640 * 2 + 5, 480 * 2 + 10),
    )
    try:
        model.run_demo(iter(dataloader), writer)
    finally:
        # Finalize the video deterministically, even if the demo fails,
        # instead of relying on garbage collection at interpreter exit.
        writer.release()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Visualize matches")
    parser.add_argument(
        "--gpu",
        "-gpu",
        type=str,
        default=None,
        help="GPU id(s) exported as CUDA_VISIBLE_DEVICES "
        "(default: keep the current environment)",
    )
    parser.add_argument(
        "--method",
        type=str,
        required=True,
        help="config name under viz/configs (without the .yml extension)",
    )
    parser.add_argument("--dataset_dir", type=str, default="data/aachen-day-night")
    parser.add_argument(
        "--pair_dir",
        type=str,
        default=None,
        help="pair list for matching, or image list for --run_demo",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        choices=["megadepth", "scannet", "aachen_v1.1", "inloc"],
        default="megadepth",
    )
    parser.add_argument("--measure_time", action="store_true")
    parser.add_argument("--no_viz", action="store_true")
    parser.add_argument("--compute_eval_metrics", action="store_true")
    parser.add_argument("--run_demo", action="store_true")

    args = parser.parse_args()

    if args.gpu is not None:
        # CUDA reads this variable when it initializes, so set it before the
        # model is built.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    model_cfg = get_model_config(args.method, args.dataset_name)
    model = viz.__dict__[model_cfg["class"]](model_cfg)

    if args.run_demo:
        _run_demo(args, model)
    else:
        _run_matching(args, model, model_cfg)