# PnLCalib — scripts/inference_sn.py
# SoccerNet camera-calibration inference: predicts (or derives from GT)
# per-frame camera parameters and archives them as zip files.
import os
import sys
import json
import glob
import yaml
import torch
import zipfile
import argparse
import warnings
import numpy as np
import torchvision.transforms as T
import torchvision.transforms.functional as f
from tqdm import tqdm
from PIL import Image
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from model.cls_hrnet import get_cls_net
from model.cls_hrnet_l import get_cls_net as get_cls_net_l
from utils.utils_keypoints import KeypointsDB
from utils.utils_lines import LineKeypointsDB
from utils.utils_heatmap import get_keypoints_from_heatmap_batch_maxpool, get_keypoints_from_heatmap_batch_maxpool_l, \
coords_to_dict, complete_keypoints
from utils.utils_calib import FramebyFrameCalib
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=np.RankWarning)
def parse_args():
    """Parse command-line arguments for SoccerNet calibration inference.

    Returns:
        argparse.Namespace with model configs/weights, dataset location,
        detection thresholds and runtime flags.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cfg", type=str, required=True,
                        help="Path to the (kp model) configuration file")
    parser.add_argument("--cfg_l", type=str, required=True,
                        help="Path to the (line model) configuration file")
    parser.add_argument("--root_dir", type=str, required=True,
                        help="Root directory")
    parser.add_argument("--split", type=str, required=True,
                        help="Dataset split")
    parser.add_argument("--save_dir", type=str, required=True,
                        help="Saving file path")
    parser.add_argument("--weights_kp", type=str, required=True,
                        help="Model (keypoints) weights to use")
    parser.add_argument("--weights_line", type=str, required=True,
                        help="Model (lines) weights to use")
    parser.add_argument("--cuda", type=str, default="cuda:0",
                        help="CUDA device index (default: 'cuda:0')")
    # Defaults were previously string literals ("0.1"); argparse converts
    # string defaults via `type` only lazily — use real floats instead.
    parser.add_argument("--kp_th", type=float, default=0.1,
                        help="Keypoint confidence threshold")
    parser.add_argument("--line_th", type=float, default=0.1,
                        help="Line confidence threshold")
    parser.add_argument("--max_reproj_err", type=float, default=50.0,
                        help="Maximum admissible reprojection error")
    parser.add_argument("--main_cam_only", action='store_true',
                        help="Process only 'Main camera center' frames")
    parser.add_argument('--use_gt', action='store_true',
                        help='Use ground truth annotations (default: False)')
    args = parser.parse_args()
    return args
if __name__ == "__main__":
    args = parse_args()

    # All frames of the requested split. os.path.join (instead of string
    # concatenation) keeps this correct whether or not root_dir ends in '/'.
    files = glob.glob(os.path.join(args.root_dir, args.split, "*.jpg"))

    if args.main_cam_only:
        # Keep only frames annotated as shot by the main centre camera.
        cam_info_path = os.path.join(args.root_dir, args.split,
                                     'match_info_cam_gt.json')
        with open(cam_info_path) as cam_info_file:
            cam_info = json.load(cam_info_file)
        files = [file for file in files
                 if os.path.basename(file) in cam_info]
        files = [file for file in files
                 if cam_info[os.path.basename(file)]['camera'] == 'Main camera center']

    # Output archives: one for the raw per-frame annotations, one for the
    # camera parameters; suffixes encode the filtering and GT/pred mode.
    suffix = '_main' if args.main_cam_only else ''
    mode = '_gt' if args.use_gt else '_pred'
    zip_name = os.path.join(args.save_dir, args.split + suffix + '.zip')
    zip_name_pred = os.path.join(args.save_dir, args.split + suffix + mode + '.zip')

    print(f"Saving results in {args.save_dir}")
    print(f"file: {zip_name_pred}")

    # Integer counters (were floats) for the summary printed at the end.
    samples, complete = 0, 0

    if args.use_gt:
        # Oracle mode: build heatmaps straight from the GT annotations and
        # run the same calibration pipeline on them.
        cam = FramebyFrameCalib(960, 540, denormalize=True)

        with zipfile.ZipFile(zip_name_pred, 'w') as zip_file:
            for file in tqdm(files, desc="Processing Images"):
                image = Image.open(file)
                file_name = os.path.splitext(os.path.basename(file))[0]
                samples += 1

                # Named `ann_file` (not `f`) so the torchvision functional
                # alias imported as `f` is not shadowed; `with` guarantees
                # the handle is closed (it previously leaked).
                with open(os.path.splitext(file)[0] + ".json") as ann_file:
                    data = json.load(ann_file)

                kp_db = KeypointsDB(data, image)
                line_db = LineKeypointsDB(data, image)
                heatmaps, _ = kp_db.get_tensor_w_mask()
                heatmaps = torch.tensor(heatmaps).unsqueeze(0)
                heatmaps_l = torch.tensor(line_db.get_tensor()).unsqueeze(0)

                # Drop the last (background) channel before peak extraction.
                kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:, :-1, :, :])
                line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:, :-1, :, :])
                kp_dict = coords_to_dict(kp_coords, threshold=0.01)
                lines_dict = coords_to_dict(line_coords, threshold=0.01)
                # NOTE(review): unlike the prediction branch, the dicts are
                # neither indexed with [0] nor passed through
                # complete_keypoints here — confirm cam.update() accepts both.

                cam.update(kp_dict, lines_dict)
                final_params_dict = cam.heuristic_voting()

                if final_params_dict:
                    complete += 1
                    zip_file.writestr(f"camera_{file_name}.json",
                                      json.dumps(final_params_dict['cam_params']))
    else:
        # Prediction mode: run the keypoint and line HRNet models.
        device = torch.device(args.cuda if torch.cuda.is_available() else 'cpu')

        # Config files opened with `with` so the handles are closed.
        with open(args.cfg, 'r') as cfg_file:
            cfg = yaml.safe_load(cfg_file)
        with open(args.cfg_l, 'r') as cfg_l_file:
            cfg_l = yaml.safe_load(cfg_l_file)

        # Keypoint model.
        model = get_cls_net(cfg)
        model.load_state_dict(torch.load(args.weights_kp, map_location=device))
        model.to(device)
        model.eval()

        # Line model.
        model_l = get_cls_net_l(cfg_l)
        model_l.load_state_dict(torch.load(args.weights_line, map_location=device))
        model_l.to(device)
        model_l.eval()

        transform = T.Resize((540, 960))
        cam = FramebyFrameCalib(960, 540)

        with zipfile.ZipFile(zip_name_pred, 'w') as zip_file:
            for file in tqdm(files, desc="Processing Images"):
                image = Image.open(file)
                file_name = os.path.splitext(os.path.basename(file))[0]
                samples += 1

                with torch.no_grad():
                    # Batch of one 540x960 float tensor; resize if needed.
                    image = f.to_tensor(image).float().to(device).unsqueeze(0)
                    image = image if image.size()[-1] == 960 else transform(image)
                    b, c, h, w = image.size()
                    heatmaps = model(image)
                    heatmaps_l = model_l(image)

                # Drop the last (background) channel before peak extraction.
                kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:, :-1, :, :])
                line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:, :-1, :, :])
                kp_dict = coords_to_dict(kp_coords, threshold=args.kp_th)
                lines_dict = coords_to_dict(line_coords, threshold=args.line_th)
                kp_dict, lines_dict = complete_keypoints(kp_dict[0], lines_dict[0], w=w, h=h)

                cam.update(kp_dict, lines_dict)
                final_params_dict = cam.heuristic_voting(refine_lines=True)

                # Keep only calibrations with an acceptable reprojection error.
                if final_params_dict and final_params_dict['rep_err'] <= args.max_reproj_err:
                    complete += 1
                    zip_file.writestr(f"camera_{file_name}.json",
                                      json.dumps(final_params_dict['cam_params']))

    # Archive the raw per-frame annotation files alongside the results.
    with zipfile.ZipFile(zip_name, 'w') as zip_file:
        for file in tqdm(files, desc="Processing Images"):
            file_name = os.path.splitext(os.path.basename(file))[0]
            with open(os.path.splitext(file)[0] + ".json") as ann_file:
                data = json.load(ann_file)
            zip_file.writestr(f"{file_name}.json", json.dumps(data))

    print(f'Completed {complete} / {samples}')