# File size: 6,326 Bytes | 3d1f2c9
import os
import sys
import json
import glob
from typing import final
import yaml
import torch
import zipfile
import argparse
import warnings
import numpy as np
import torchvision.transforms as T
import torchvision.transforms.functional as f
from tqdm import tqdm
from PIL import Image
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from model.cls_hrnet import get_cls_net
from model.cls_hrnet_l import get_cls_net as get_cls_net_l
from utils.utils_heatmap import get_keypoints_from_heatmap_batch_maxpool, get_keypoints_from_heatmap_batch_maxpool_l, \
complete_keypoints, coords_to_dict
from utils.utils_keypoints import KeypointsDB
from utils.utils_lines import LineKeypointsDB
def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the script does when run from the command line.

    Returns:
        argparse.Namespace: Parsed options with the attributes ``cfg``,
        ``cfg_l``, ``root_dir``, ``split``, ``save_dir``, ``weights_kp``,
        ``weights_line``, ``cuda``, ``kp_th``, ``line_th``, ``batch``,
        ``num_workers`` and ``use_gt``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or
            a value cannot be converted to the declared type.
    """
    parser = argparse.ArgumentParser()

    # Model configuration files.
    parser.add_argument("--cfg", type=str, required=True,
                        help="Path to the (kp model) configuration file")
    parser.add_argument("--cfg_l", type=str, required=True,
                        help="Path to the (line model) configuration file")

    # Dataset location and output location.
    parser.add_argument("--root_dir", type=str, required=True,
                        help="Root directory")
    parser.add_argument("--split", type=str, required=True,
                        help="Dataset split")
    parser.add_argument("--save_dir", type=str, required=True,
                        help="Output directory for the generated zip archive")

    # Model checkpoints.
    parser.add_argument("--weights_kp", type=str, required=True,
                        help="Model (keypoints) weights to use")
    parser.add_argument("--weights_line", type=str, required=True,
                        help="Model (lines) weights to use")

    # Runtime options.
    parser.add_argument("--cuda", type=str, default="cuda:0",
                        help="CUDA device index (default: 'cuda:0')")
    # Real float defaults instead of strings that argparse has to re-parse.
    parser.add_argument("--kp_th", type=float, default=0.1,
                        help="Keypoint threshold (default: 0.1)")
    parser.add_argument("--line_th", type=float, default=0.1,
                        help="Line threshold (default: 0.1)")
    parser.add_argument("--batch", type=int, default=1, help="Batch size")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers")
    parser.add_argument('--use_gt', action='store_true', help='Use ground truth (default: False)')

    return parser.parse_args(argv)
def _image_stem(path):
    """Return the file name of *path* without directory and extension.

    Uses ``os.path`` so that dots in directory names (``./data/``,
    ``v1.0/``) and platform-specific separators do not corrupt the result.
    """
    return os.path.splitext(os.path.basename(path))[0]


def _load_yaml(path):
    """Parse the YAML file at *path* and close the handle afterwards."""
    with open(path, 'r') as cfg_file:
        return yaml.safe_load(cfg_file)


def _export_ground_truth(files, zip_name):
    """Write one JSON entry per image to *zip_name*, built from annotations.

    For every image in *files* the sibling ``<image>.json`` annotation file
    is loaded, turned into heatmaps through ``KeypointsDB`` and
    ``LineKeypointsDB``, and the keypoints extracted from those heatmaps are
    stored in the archive as ``<image stem>.json``.

    Args:
        files: Paths of the images to process.
        zip_name: Path of the zip archive to create (overwritten if present).
    """
    with zipfile.ZipFile(zip_name, 'w') as zip_file:
        for file in tqdm(files, desc="Processing Images"):
            # Keep the image open while the DB objects are used: they receive
            # the PIL image and may read from it lazily.
            with Image.open(file) as image:
                w, h = image.size

                # The annotation sits next to the image with a .json suffix.
                json_path = os.path.splitext(file)[0] + ".json"
                with open(json_path) as json_file:
                    data = json.load(json_file)

                kp_db = KeypointsDB(data, image)
                line_db = LineKeypointsDB(data, image)
                heatmaps, _ = kp_db.get_tensor_w_mask()
                heatmaps_l = line_db.get_tensor()

            # Add a leading batch axis of size 1.
            heatmaps = torch.tensor(heatmaps).unsqueeze(0)
            heatmaps_l = torch.tensor(heatmaps_l).unsqueeze(0)

            # The last channel (axis 1) is excluded from keypoint extraction.
            kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:, :-1, :, :])
            line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:, :-1, :, :])
            kp_dict = coords_to_dict(kp_coords, threshold=0.1)
            lines_dict = coords_to_dict(line_coords, threshold=0.1)

            # NOTE(review): the prediction path passes kp_dict[0] / lines_dict[0]
            # and unpacks two return values from complete_keypoints, while this
            # path passes the full coords_to_dict results and keeps the single
            # return value. The original call is preserved here; confirm which
            # call shape complete_keypoints actually expects.
            final_kp_dict = complete_keypoints(kp_dict, lines_dict, w=w, h=h, normalize=True)
            final_dict = {'kp': final_kp_dict, 'lines': lines_dict}
            zip_file.writestr(f"{_image_stem(file)}.json", json.dumps(final_dict))


def _export_predictions(args, files, zip_name):
    """Write one JSON entry per image to *zip_name*, built from model output.

    Loads the keypoint and line models described by ``args.cfg`` /
    ``args.cfg_l`` with the weights ``args.weights_kp`` /
    ``args.weights_line``, runs both on every image in *files* and stores the
    completed keypoints and lines in the archive as ``<image stem>.json``.

    Args:
        args: Parsed command-line options (uses ``cuda``, ``cfg``, ``cfg_l``,
            ``weights_kp``, ``weights_line``, ``kp_th`` and ``line_th``).
        files: Paths of the images to process.
        zip_name: Path of the zip archive to create (overwritten if present).
    """
    device = torch.device(args.cuda if torch.cuda.is_available() else 'cpu')
    cfg = _load_yaml(args.cfg)
    cfg_l = _load_yaml(args.cfg_l)

    # NOTE(review): torch.load unpickles the checkpoint file; only load
    # weights that come from a trusted source.
    model = get_cls_net(cfg)
    model.load_state_dict(torch.load(args.weights_kp, map_location=device))
    model.to(device)
    model.eval()

    model_l = get_cls_net_l(cfg_l)
    model_l.load_state_dict(torch.load(args.weights_line, map_location=device))
    model_l.to(device)
    model_l.eval()

    transform = T.Resize((540, 960))

    with zipfile.ZipFile(zip_name, 'w') as zip_file:
        for file in tqdm(files, desc="Processing Images"):
            # to_tensor reads the pixel data, so the file can be closed
            # as soon as the tensor has been built.
            with Image.open(file) as pil_image:
                image = f.to_tensor(pil_image).float().to(device).unsqueeze(0)

            # Only the width is checked; anything not 960 wide is resized
            # to 540x960.
            if image.size()[-1] != 960:
                image = transform(image)
            h, w = image.size()[-2:]

            with torch.no_grad():
                heatmaps = model(image)
                heatmaps_l = model_l(image)

            # The last channel (axis 1) is excluded from keypoint extraction.
            kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:, :-1, :, :])
            line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:, :-1, :, :])
            kp_dict = coords_to_dict(kp_coords, threshold=args.kp_th, ground_plane_only=True)
            lines_dict = coords_to_dict(line_coords, threshold=args.line_th, ground_plane_only=True)

            # Images are processed one at a time, so only batch entry 0 exists.
            final_kp_dict, final_lines_dict = complete_keypoints(
                kp_dict[0], lines_dict[0], w=w, h=h, normalize=True)
            final_dict = {'kp': final_kp_dict, 'lines': final_lines_dict}
            zip_file.writestr(f"{_image_stem(file)}.json", json.dumps(final_dict))


if __name__ == "__main__":
    args = parse_args()

    # root_dir and save_dir are used as plain string prefixes in front of the
    # split name, so they are expected to end with a path separator.
    image_dir = args.root_dir + args.split
    files = glob.glob(os.path.join(image_dir, "*.jpg"))
    if not files:
        warnings.warn(f"No .jpg images found in {image_dir}; the archive will be empty.")

    zip_suffix = '_gt.zip' if args.use_gt else '_pred.zip'
    zip_name_pred = args.save_dir + args.split + zip_suffix

    if args.use_gt:
        _export_ground_truth(files, zip_name_pred)
    else:
        _export_predictions(args, files, zip_name_pred)
|