Vincentqyw
fix: roma
8b973ee
raw
history blame
1.99 kB
import torch
import yaml
import time
from collections import OrderedDict, namedtuple
import os
import sys
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, ROOT_DIR)
from sgmnet import matcher as SGM_Model
from superglue import matcher as SG_Model
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--matcher_name", type=str, default="SGM", help="number of processes."
)
parser.add_argument(
"--config_path",
type=str,
default="configs/cost/sgm_cost.yaml",
help="number of processes.",
)
parser.add_argument(
"--num_kpt", type=int, default=4000, help="keypoint number, default:100"
)
parser.add_argument(
"--iter_num", type=int, default=100, help="keypoint number, default:100"
)
def test_cost(test_data, model):
with torch.no_grad():
# warm up call
_ = model(test_data)
torch.cuda.synchronize()
a = time.time()
for _ in range(int(args.iter_num)):
_ = model(test_data)
torch.cuda.synchronize()
b = time.time()
print("Average time per run(ms): ", (b - a) / args.iter_num * 1e3)
print("Peak memory(MB): ", torch.cuda.max_memory_allocated() / 1e6)
if __name__ == "__main__":
torch.backends.cudnn.benchmark = False
args = parser.parse_args()
with open(args.config_path, "r") as f:
model_config = yaml.load(f)
model_config = namedtuple("model_config", model_config.keys())(
*model_config.values()
)
if args.matcher_name == "SGM":
model = SGM_Model(model_config)
elif args.matcher_name == "SG":
model = SG_Model(model_config)
model.cuda(), model.eval()
test_data = {
"x1": torch.rand(1, args.num_kpt, 2).cuda() - 0.5,
"x2": torch.rand(1, args.num_kpt, 2).cuda() - 0.5,
"desc1": torch.rand(1, args.num_kpt, 128).cuda(),
"desc2": torch.rand(1, args.num_kpt, 128).cuda(),
}
test_cost(test_data, model)