"""Measure per-run latency and peak CUDA memory of an SGM or SG matcher.

Builds the matcher selected by ``--matcher_name`` from the YAML file at
``--config_path``, feeds it random keypoints/descriptors of size
``--num_kpt``, and reports the average time over ``--iter_num`` runs.
"""
import torch
import yaml
import time
from collections import OrderedDict, namedtuple
import os
import sys

# Make the repository root (parent of this file's directory) importable so the
# matcher packages below resolve regardless of the current working directory.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, ROOT_DIR)

from sgmnet import matcher as SGM_Model
from superglue import matcher as SG_Model

import argparse

# Maps each accepted --matcher_name value to the matcher class it builds.
_MATCHERS = {"SGM": SGM_Model, "SG": SG_Model}

parser = argparse.ArgumentParser()
parser.add_argument(
    "--matcher_name",
    type=str,
    default="SGM",
    # Restricting choices makes argparse reject unknown names up front instead
    # of failing later with an unbound `model`.
    choices=sorted(_MATCHERS),
    help="matcher to benchmark: SGM (sgmnet) or SG (superglue), default: SGM",
)
parser.add_argument(
    "--config_path",
    type=str,
    default="configs/cost/sgm_cost.yaml",
    help="path to the YAML matcher config, default: configs/cost/sgm_cost.yaml",
)
parser.add_argument(
    "--num_kpt", type=int, default=4000, help="keypoints per image, default: 4000"
)
parser.add_argument(
    "--iter_num", type=int, default=100, help="number of timed runs, default: 100"
)


def test_cost(test_data, model, iter_num=None):
    """Benchmark ``model`` on ``test_data`` and print average latency and peak memory.

    One untimed warm-up call is made first, then ``iter_num`` timed calls, all
    under ``torch.no_grad()``. ``torch.cuda.synchronize()`` is called before
    each clock read so queued asynchronous CUDA work is included in the timing.

    Args:
        test_data: Input passed unchanged to ``model(...)`` on every call.
        model: Callable under test.
        iter_num: Number of timed calls. When omitted, falls back to the
            module-level ``args.iter_num`` (the original behavior, which only
            works after ``parser.parse_args()`` has run in ``__main__``).

    Raises:
        ValueError: If ``iter_num`` is less than 1 (the average would divide
            by zero).
    """
    if iter_num is None:
        iter_num = args.iter_num
    iter_num = int(iter_num)
    if iter_num < 1:
        raise ValueError(f"iter_num must be >= 1, got {iter_num}")

    with torch.no_grad():
        # Warm-up call, excluded from the timing.
        _ = model(test_data)
        torch.cuda.synchronize()
        start = time.perf_counter()  # monotonic, high-resolution clock
        for _ in range(iter_num):
            _ = model(test_data)
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start

    print("Average time per run(ms): ", elapsed / iter_num * 1e3)
    # max_memory_allocated() is in bytes; dividing by 1e6 gives (decimal) MB.
    print("Peak memory(MB): ", torch.cuda.max_memory_allocated() / 1e6)


if __name__ == "__main__":
    torch.backends.cudnn.benchmark = False
    args = parser.parse_args()

    # yaml.load() without an explicit Loader is a TypeError on PyYAML >= 6.
    # NOTE(review): safe_load assumes the config uses only plain YAML types;
    # switch to yaml.load(f, Loader=yaml.FullLoader) if it relies on Python tags.
    with open(args.config_path, "r") as f:
        model_config = yaml.safe_load(f)
    # Expose the config's top-level keys as attributes (model_config.<key>).
    model_config = namedtuple("model_config", model_config.keys())(
        *model_config.values()
    )

    model = _MATCHERS[args.matcher_name](model_config)
    model.cuda()
    model.eval()

    # Random inputs: coordinates shifted into [-0.5, 0.5), 128-d descriptors
    # in [0, 1), each with a leading batch dimension of 1.
    test_data = {
        "x1": torch.rand(1, args.num_kpt, 2).cuda() - 0.5,
        "x2": torch.rand(1, args.num_kpt, 2).cuda() - 0.5,
        "desc1": torch.rand(1, args.num_kpt, 128).cuda(),
        "desc2": torch.rand(1, args.num_kpt, 128).cuda(),
    }
    test_cost(test_data, model, args.iter_num)