Spaces:
Runtime error
Runtime error
File size: 4,621 Bytes
128e4f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import argparse
import torch
import os
import json
from tqdm import tqdm
import project_subpath
from backend.InferenceConfig import InferenceConfig
from backend.dataloader import create_dataloader_frames_only
from backend.inference import do_full_tracking, setup_model, do_detection
def main(args, config=None, verbose=True):
    """Run detection and tracking on every sequence of one location.

    Args:
        args: Namespace as produced by ``argument_parser().parse_args()``.
            Attributes read:
            - ``frames``: root frame directory; every entry of
              ``frames/<location>`` is processed as one sequence.
            - ``location``: name of the location sub-directory.
            - ``metadata``: directory containing ``<location>.json``.
            - ``output``: root output directory; results are written to
              ``<output>/<location>/tracker/data``.
            - ``weights``: path to the saved YOLOv5 weights passed to
              ``setup_model``.
        config: Inference configuration. ``None`` (the default) creates a
            fresh ``InferenceConfig()`` for this call.
        verbose: If True, print a header for each sequence and forward
            ``verbose`` to the per-sequence inference.
    """
    # Build the default here rather than in the signature: a signature
    # default is evaluated once at import time and shared by all calls.
    if config is None:
        config = InferenceConfig()

    print("In task...")
    print("Cuda available in task?", torch.cuda.is_available())
    print("Config:", config.to_dict())

    loc = args.location
    in_loc_dir = os.path.join(args.frames, loc)
    out_dir = os.path.join(args.output, loc, "tracker", "data")
    metadata_path = os.path.join(args.metadata, loc + ".json")
    os.makedirs(out_dir, exist_ok=True)
    print(in_loc_dir)
    print(out_dir)
    print(metadata_path)

    # Run detection + tracking; the model is loaded once and reused for
    # every sequence.
    model, device = setup_model(args.weights)

    # os.listdir returns entries in arbitrary order; sort so sequences are
    # processed and logged deterministically.
    seq_list = sorted(os.listdir(in_loc_dir))
    total = len(seq_list)
    with tqdm(total=total, desc="...", ncols=0) as pbar:
        for idx, seq in enumerate(seq_list, start=1):
            pbar.update(1)
            pbar.set_description("Processing " + seq)
            if verbose:
                print(" ")
                print(f"({idx}/{total}) {seq}")
                print(" ")
            in_seq_dir = os.path.join(in_loc_dir, seq)
            infer_seq(in_seq_dir, out_dir, config, seq, model, device,
                      metadata_path, verbose)
def _load_image_meter_size(metadata_path, seq_name):
    """Look up the real-world (meter) extent of ``seq_name`` in a metadata JSON file.

    Args:
        metadata_path: Path to a JSON file holding a list of clip records.
        seq_name: Value compared against each record's ``clip_name``.

    Returns:
        ``(width, height)`` computed as ``x_meter_stop - x_meter_start`` and
        ``y_meter_stop - y_meter_start``, or ``None`` if no record matches.
        If several records match, the last one wins.
    """
    with open(metadata_path, "r", encoding="utf-8") as f:
        clips = json.load(f)
    size = None
    for clip in clips:
        if clip["clip_name"] == seq_name:
            size = (clip["x_meter_stop"] - clip["x_meter_start"],
                    clip["y_meter_stop"] - clip["y_meter_start"])
    return size


def _to_mot_rows(results, real_width, real_height):
    """Format tracking results as 10-column MOT-style CSV lines.

    Each line is ``frame,track_id,left,top,width,height,-1,-1,-1,-1``. Frame
    numbers and fish ids are shifted by +1, and box values are scaled by the
    pixel size and truncated with ``int``.
    """
    rows = []
    for frame in results["frames"]:
        for fish in frame["fish"]:
            bbox = fish["bbox"]
            # NOTE(review): bbox appears to be normalized (x1, y1, x2, y2);
            # it is scaled to pixels here.
            left = bbox[0] * real_width
            top = bbox[1] * real_height
            w = bbox[2] * real_width - bbox[0] * real_width
            h = bbox[3] * real_height - bbox[1] * real_height
            fields = [
                str(frame["frame_num"] + 1),
                str(fish["fish_id"] + 1),
                str(int(left)),
                str(int(top)),
                str(int(w)),
                str(int(h)),
                "-1", "-1", "-1", "-1",
            ]
            rows.append(",".join(fields))
    return rows


def infer_seq(in_dir, out_dir, config, seq_name, model, device, metadata_path, verbose):
    """Detect and track one sequence and write its MOT-style result file.

    Args:
        in_dir: Frame directory for this sequence.
        out_dir: Directory that receives ``<seq_name>.txt``, or
            ``ERROR_<seq_name>.txt`` if detection raises.
        config: Inference configuration forwarded to ``do_full_tracking``.
        seq_name: Sequence name; also the ``clip_name`` looked up in metadata.
        model, device: As returned by ``setup_model``.
        metadata_path: Path to the location's metadata JSON file.
        verbose: Forwarded to detection and tracking.

    Returns:
        None. Returns early without writing anything when the metadata has
        no record for ``seq_name``.
    """
    # ``None`` (rather than a -1 sentinel) marks "not found", so a record
    # whose computed extent happens to equal -1 is not treated as missing.
    meter_size = _load_image_meter_size(metadata_path, seq_name)
    if meter_size is None:
        print("No metadata found for file " + seq_name)
        return
    image_meter_width, image_meter_height = meter_size

    dataloader = create_dataloader_frames_only(in_dir)
    try:
        inference, image_shapes, width, height = do_detection(
            dataloader, model, device, verbose=verbose)
    except Exception as err:
        # Best effort per sequence: record the failure and return so the
        # caller can continue with the next sequence. ``Exception`` (not a
        # bare except) lets KeyboardInterrupt / SystemExit propagate.
        print("Error in " + seq_name)
        print(f"{type(err).__name__}: {err}")
        with open(os.path.join(out_dir, "ERROR_" + seq_name + ".txt"), "w") as f:
            f.write("ERROR")
        return

    # NOTE(review): presumably image_shapes[0][0][0] is (height, width) of
    # the first frame in pixels; confirm against do_detection.
    real_width = image_shapes[0][0][0][1]
    real_height = image_shapes[0][0][0][0]

    results = do_full_tracking(
        inference, image_shapes, image_meter_width, image_meter_height,
        width, height, config=config, gp=None, verbose=verbose)

    with open(os.path.join(out_dir, seq_name + ".txt"), "w") as f:
        f.write("\n".join(_to_mot_rows(results, real_width, real_height)))
def argument_parser():
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` with required ``--frames``,
        ``--location``, ``--metadata`` and ``--output`` options and an
        optional ``--weights`` path (default ``models/v5m_896_300best.pt``,
        a relative path resolved against the working directory).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--frames", required=True, help="Path to frame directory. Required.")
    parser.add_argument("--location", required=True, help="Name of location dir. Required.")
    parser.add_argument("--metadata", required=True, help="Path to metadata directory. Required.")
    parser.add_argument("--output", required=True, help="Path to output directory. Required.")
    # %(default)s keeps the help text in sync with the actual default; the
    # hard-coded text previously advertised ../models/..., which was wrong.
    parser.add_argument(
        "--weights",
        default="models/v5m_896_300best.pt",
        help="Path to saved YOLOv5 weights. Default: %(default)s",
    )
    return parser
if __name__ == "__main__":
    # Script entry point: parse CLI arguments and run the whole location.
    args = argument_parser().parse_args()
    main(args)