File size: 5,597 Bytes
3d1f2c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import os
import sys
import json
import glob
import yaml
import torch
import zipfile
import argparse
import warnings
import numpy as np
import torchvision.transforms as T
import torchvision.transforms.functional as f
from tqdm import tqdm
from PIL import Image
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from model.cls_hrnet import get_cls_net
from model.cls_hrnet_l import get_cls_net as get_cls_net_l
from utils.utils_heatmap import get_keypoints_from_heatmap_batch_maxpool, get_keypoints_from_heatmap_batch_maxpool_l, \
complete_keypoints, coords_to_dict
from utils.utils_keypoints import KeypointsDB
from utils.utils_lines import LineKeypointsDB
def parse_args():
    """Parse command-line options for keypoint/line model inference.

    Returns:
        argparse.Namespace: parsed arguments (paths, device, thresholds,
        batching options).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cfg", type=str, required=True,
                        help="Path to the (kp model) configuration file")
    parser.add_argument("--cfg_l", type=str, required=True,
                        help="Path to the (line model) configuration file")
    parser.add_argument("--root_dir", type=str, required=True,
                        help="Root directory")
    parser.add_argument("--split", type=str, required=True,
                        help="Dataset split")
    parser.add_argument("--save_dir", type=str, required=True,
                        help="Directory where the prediction zip is written")
    parser.add_argument("--weights_kp", type=str, required=True,
                        help="Model (keypoints) weights to use")
    parser.add_argument("--weights_line", type=str, required=True,
                        help="Model (lines) weights to use")
    parser.add_argument("--cuda", type=str, default="cuda:0",
                        help="CUDA device index (default: 'cuda:0')")
    # Defaults are genuine floats; the originals were strings that only
    # worked because argparse applies `type` to string defaults.
    parser.add_argument("--kp_th", type=float, default=0.1,
                        help="Keypoint confidence threshold")
    parser.add_argument("--line_th", type=float, default=0.1,
                        help="Line confidence threshold")
    parser.add_argument("--batch", type=int, default=1, help="Batch size")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers")
    args = parser.parse_args()
    return args
def get_files(file_paths, root_dir=None):
    """Collect all .jpg image paths under Dataset/80_95/<seq> for each sequence.

    Args:
        file_paths: iterable of sequence sub-directory names.
        root_dir: dataset root directory. Defaults to the module-level
            ``args.root_dir`` for backward compatibility with the original
            script, which read the global directly.

    Returns:
        list[str]: sorted .jpg file paths; sequences whose directory does
        not exist are silently skipped (original behavior).
    """
    if root_dir is None:
        root_dir = args.root_dir  # fall back to the script-level globals
    base = os.path.join(root_dir, "Dataset/80_95")
    jpg_files = []
    for file_path in file_paths:
        directory_path = os.path.join(base, file_path)
        if os.path.exists(directory_path):
            jpg_files.extend(
                os.path.join(directory_path, name)
                for name in os.listdir(directory_path)
                if name.endswith('.jpg')
            )
    return sorted(jpg_files)
def get_homographies(file_paths, root_dir=None):
    """Collect all .npy homography paths under Annotations/80_95/<seq>.

    Args:
        file_paths: iterable of sequence sub-directory names.
        root_dir: dataset root directory. Defaults to the module-level
            ``args.root_dir`` for backward compatibility with the original
            script, which read the global directly.

    Returns:
        list[str]: sorted .npy file paths; sequences whose directory does
        not exist are silently skipped (original behavior).
    """
    if root_dir is None:
        root_dir = args.root_dir  # fall back to the script-level globals
    base = os.path.join(root_dir, "Annotations/80_95")
    npy_files = []
    for file_path in file_paths:
        directory_path = os.path.join(base, file_path)
        if os.path.exists(directory_path):
            npy_files.extend(
                os.path.join(directory_path, name)
                for name in os.listdir(directory_path)
                if name.endswith('.npy')
            )
    return sorted(npy_files)
def make_file_name(file):
    """Derive a unique output name '<side>-<match>-IMG_<frame>' from an image path.

    The path is first re-rooted at its 'TS-WorldCup/' component, so absolute
    prefixes before it are ignored.
    """
    relative = "TS-WorldCup/" + file.split("TS-WorldCup/")[-1]
    parts = relative.split('/')
    side, match, image = parts[3], parts[4], parts[5]
    frame_id = image.split('.')[0].split('_')[-1]
    return f"{side}-{match}-IMG_{frame_id}"
if __name__ == "__main__":
    args = parse_args()
    # NOTE(review): plain string concatenation — assumes root_dir ends with a
    # path separator (e.g. "data/"); confirm against how the script is invoked.
    with open(args.root_dir + args.split + '.txt', 'r') as file:
        # Read lines from the file and remove trailing newline characters
        seqs = [line.strip() for line in file.readlines()]
    files = get_files(seqs)
    # NOTE(review): homographies are collected but never used below.
    homographies = get_homographies(seqs)
    # Same separator assumption as above for save_dir.
    zip_name_pred = args.save_dir + args.split + '_pred.zip'
    # Fall back to CPU when CUDA is not available.
    device = torch.device(args.cuda if torch.cuda.is_available() else 'cpu')
    cfg = yaml.safe_load(open(args.cfg, 'r'))
    cfg_l = yaml.safe_load(open(args.cfg_l, 'r'))
    # Keypoint model: load weights, move to device, switch to eval mode.
    loaded_state = torch.load(args.weights_kp, map_location=device)
    model = get_cls_net(cfg)
    model.load_state_dict(loaded_state)
    model.to(device)
    model.eval()
    # Line model: identical setup with its own config and weights.
    loaded_state_l = torch.load(args.weights_line, map_location=device)
    model_l = get_cls_net_l(cfg_l)
    model_l.load_state_dict(loaded_state_l)
    model_l.to(device)
    model_l.eval()
    # Inference resolution used when the input width differs from 960.
    transform = T.Resize((540, 960))
    # All per-frame predictions are streamed into a single zip archive.
    with zipfile.ZipFile(zip_name_pred, 'w') as zip_file:
        for count in tqdm(range(len(files)), desc="Processing Images"):
            image = Image.open(files[count])
            # HWC PIL image -> float CHW tensor with a leading batch dim of 1.
            image = f.to_tensor(image).float().to(device).unsqueeze(0)
            # Resize only when the width is not already 960.
            image = image if image.size()[-1] == 960 else transform(image)
            b, c, h, w = image.size()
            with torch.no_grad():
                heatmaps = model(image)
                heatmaps_l = model_l(image)
            # Drop the last heatmap channel before peak extraction —
            # presumably a background/no-class channel; confirm in the model.
            kp_coords = get_keypoints_from_heatmap_batch_maxpool(heatmaps[:,:-1,:,:])
            line_coords = get_keypoints_from_heatmap_batch_maxpool_l(heatmaps_l[:,:-1,:,:])
            kp_dict = coords_to_dict(kp_coords, threshold=args.kp_th, ground_plane_only=True)
            lines_dict = coords_to_dict(line_coords, threshold=args.line_th, ground_plane_only=True)
            # Index 0: the batch holds exactly one image (unsqueeze(0) above).
            final_kp_dict, final_lines_dict = complete_keypoints(kp_dict[0], lines_dict[0],
                                                                 w=w, h=h, normalize=True)
            final_dict = {'kp': final_kp_dict, 'lines': final_lines_dict}
            json_data = json.dumps(final_dict)
            # One JSON entry per frame, named '<side>-<match>-IMG_<frame>.json'.
            zip_file.writestr(f"{make_file_name(files[count])}.json", json_data)
|