import torch
import numpy as np
import cv2
import os
from loss import batch_episym
from tqdm import tqdm
import sys

ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, ROOT_DIR)

from utils import evaluation_utils, train_utils


def valid(valid_loader, model, match_loss, config, model_config):
    """Run one validation pass and return process-reduced metrics.

    The model is switched to eval mode for the pass (under ``torch.no_grad``)
    and back to train mode before returning.

    Args:
        valid_loader: iterable of batches; each batch is moved to the GPU with
            ``train_utils.tocuda`` before being fed to ``model``.
        model: callable mapping a batch to the prediction passed to
            ``match_loss.run``.
        match_loss: object whose ``run(data, res)`` returns a dict providing
            ``acc_corr``, ``acc_incorr``, ``total_loss``, ``pre_seed_conf``,
            ``recall_seed_conf`` and, when ``config.model_name == "SGM"``,
            ``mid_acc_corr``.
        config: reads ``local_rank`` (only rank 0 shows progress) and
            ``model_name``.
        model_config: reads ``layer_num`` (size of the precision/recall
            vectors) and ``seedlayer`` (``len - 1`` mid-accuracy slots).

    Returns:
        ``(total_loss, acc_corr, acc_incorr, precision, recall, acc_mid)``.
        ``total_loss`` is the per-batch sum handed to
        ``train_utils.reduce_tensor(..., "sum")``; every other metric is the
        per-batch mean handed to ``train_utils.reduce_tensor(..., "mean")``
        (presumably an all-reduce across processes — see train_utils).
        ``acc_mid`` stays all zeros for non-SGM models.

    Raises:
        ZeroDivisionError: if ``valid_loader`` yields no batches.
    """
    model.eval()
    num_pair = 0
    total_loss, total_acc_corr, total_acc_incorr = 0, 0, 0
    # One precision / recall slot per layer.
    total_precision = torch.zeros(model_config.layer_num, device="cuda")
    total_recall = torch.zeros(model_config.layer_num, device="cuda")
    total_acc_mid = torch.zeros(len(model_config.seedlayer) - 1, device="cuda")

    with torch.no_grad():
        batches = valid_loader
        if config.local_rank == 0:
            # Wrap the loader itself (not iter(loader)) so tqdm can use len()
            # and display a proper progress bar; only rank 0 reports.
            batches = tqdm(valid_loader)
            print("validating...")
        for test_data in batches:
            num_pair += 1
            test_data = train_utils.tocuda(test_data)
            res = model(test_data)
            loss_res = match_loss.run(test_data, res)

            total_acc_corr += loss_res["acc_corr"]
            total_acc_incorr += loss_res["acc_incorr"]
            total_loss += loss_res["total_loss"]

            if config.model_name == "SGM":
                total_acc_mid += loss_res["mid_acc_corr"]
            total_precision = total_precision + loss_res["pre_seed_conf"]
            total_recall = total_recall + loss_res["recall_seed_conf"]

        # Per-process averages over batches (total_loss is intentionally left
        # as a sum and reduced with "sum" below).
        total_acc_corr /= num_pair
        total_acc_incorr /= num_pair
        total_precision /= num_pair
        total_recall /= num_pair
        total_acc_mid /= num_pair

        # apply tensor reduction
        total_loss = train_utils.reduce_tensor(total_loss, "sum")
        total_acc_corr = train_utils.reduce_tensor(total_acc_corr, "mean")
        total_acc_incorr = train_utils.reduce_tensor(total_acc_incorr, "mean")
        total_precision = train_utils.reduce_tensor(total_precision, "mean")
        total_recall = train_utils.reduce_tensor(total_recall, "mean")
        total_acc_mid = train_utils.reduce_tensor(total_acc_mid, "mean")

    model.train()
    return (
        total_loss,
        total_acc_corr,
        total_acc_incorr,
        total_precision,
        total_recall,
        total_acc_mid,
    )


def dump_train_vis(res, data, step, config):
    """Draw the predicted matches of a batch and save one PNG per image pair.

    Matches come from ``res["p"]`` with its last row and column dropped
    (presumably dustbin entries — TODO confirm against the model). A keypoint
    of image 1 is kept when its best score exceeds 0.2 and its argmax match is
    mutual. Each kept match is coloured by whether its epipolar distance
    (``batch_episym`` against ``data["e_gt"]``) is below ``config.inlier_th``.

    Files are written to
    ``<config.train_vis_folder>/train_vis/<config.log_base>/<step>/``
    as ``<img1 name>_<img2 name>_<img1 parent dir>.png``; the directory is
    created if it does not exist.
    """
    # batch matching
    p = res["p"][:, :-1, :-1]
    score, index1 = torch.max(p, dim=-1)
    _, index2 = torch.max(p, dim=-2)
    mask_th = score > 0.2
    # Mutual check: following i -> index1[i] -> index2[index1[i]] must return
    # to i. Build the row ids on p's own device instead of the default GPU.
    row_ids = torch.arange(p.shape[1], device=p.device)[None]
    mask_mc = index2.gather(index=index1, dim=1) == row_ids
    mask_p = mask_th & mask_mc  # B*N

    gather_idx = index1[:, :, None].expand(-1, -1, 2)
    corr1, corr2 = data["x1"], data["x2"].gather(index=gather_idx, dim=1)
    corr1_kpt, corr2_kpt = data["kpt1"], data["kpt2"].gather(index=gather_idx, dim=1)
    epi_dis = batch_episym(corr1, corr2, data["e_gt"])
    mask_inlier = epi_dis < config.inlier_th  # B*N

    # cv2.imwrite does not create missing directories (it just fails without
    # raising), so make sure the per-step output directory exists up front.
    save_dir = os.path.join(
        config.train_vis_folder, "train_vis", config.log_base, str(step)
    )
    os.makedirs(save_dir, exist_ok=True)

    # dump vis
    for cur_mask_p, cur_mask_inlier, cur_corr1, cur_corr2, img_path1, img_path2 in zip(
        mask_p, mask_inlier, corr1_kpt, corr2_kpt, data["img_path1"], data["img_path2"]
    ):
        img1, img2 = cv2.imread(img_path1), cv2.imread(img_path2)
        dis_play = evaluation_utils.draw_match(
            img1,
            img2,
            cur_corr1[cur_mask_p].cpu().numpy(),
            cur_corr2[cur_mask_p].cpu().numpy(),
            # Filter the inlier flags with the same mask as the matches so
            # flag k corresponds to drawn match k (the unfiltered mask has one
            # entry per keypoint and would be misaligned).
            inlier=cur_mask_inlier[cur_mask_p],
        )
        base_name_seq = (
            os.path.basename(img_path1)
            + "_"
            + os.path.basename(img_path2)
            + "_"
            + os.path.basename(os.path.dirname(img_path1))
        )
        cv2.imwrite(os.path.join(save_dir, base_name_seq + ".png"), dis_play)