File size: 4,621 Bytes
128e4f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130

import argparse
import torch
import os
import json
from tqdm import tqdm

import project_subpath
from backend.InferenceConfig import InferenceConfig
from backend.dataloader import create_dataloader_frames_only
from backend.inference import do_full_tracking, setup_model, do_detection


def main(args, config=None, verbose=True):
    """
    Run detection + tracking on every sequence of one location and save the results.

    Every entry of ``<args.frames>/<args.location>`` is treated as one sequence
    and handed to ``infer_seq``, which writes its output into
    ``<args.output>/<args.location>/tracker/data``.

    Args:
        args: Parsed command-line namespace providing
            frames (str): Path to frame directory.
            location (str): Name of the location directory inside ``frames``.
            metadata (str): Path to metadata directory; ``<location>.json`` is read from it.
            output (str): Path to output directory.
            weights (str): Path to saved YOLOv5 weights.
        config (InferenceConfig): Configuration forwarded to tracking. A fresh
            ``InferenceConfig()`` is created per call when omitted.
        verbose (bool): Print a header per sequence and forward the flag to ``infer_seq``.
    """
    # Build the default here rather than in the signature: a default argument
    # is evaluated once at definition time and would be shared by every call.
    if config is None:
        config = InferenceConfig()

    print("In task...")
    print("Cuda available in task?", torch.cuda.is_available())

    print("Config:", config.to_dict())

    dirname = args.frames
    loc = args.location

    in_loc_dir = os.path.join(dirname, loc)
    out_dir = os.path.join(args.output, loc, "tracker", "data")
    metadata_path = os.path.join(args.metadata, loc + ".json")
    os.makedirs(out_dir, exist_ok=True)
    print(in_loc_dir)
    print(out_dir)
    print(metadata_path)

    # run detection + tracking
    model, device = setup_model(args.weights)

    # os.listdir order is arbitrary; sort so runs and logs are reproducible.
    seq_list = sorted(os.listdir(in_loc_dir))
    with tqdm(total=len(seq_list), desc="...", ncols=0) as pbar:
        for idx, seq in enumerate(seq_list, start=1):
            pbar.set_description("Processing " + seq)
            if verbose:
                print(" ")
                print("(" + str(idx) + "/" + str(len(seq_list)) + ") " + seq)
                print(" ")
            in_seq_dir = os.path.join(in_loc_dir, seq)
            infer_seq(in_seq_dir, out_dir, config, seq, model, device, metadata_path, verbose)
            # Advance only after the sequence is done so the bar counts completed work.
            pbar.update(1)

def _read_meter_extent(metadata_path, seq_name):
    """Return ``(width, height)`` computed from the x_meter_*/y_meter_* fields of the
    metadata entry whose ``clip_name`` equals ``seq_name``, or ``None`` if there is none."""
    with open(metadata_path, 'r') as f:
        clips = json.load(f)

    extent = None
    for clip in clips:
        if clip['clip_name'] == seq_name:
            # No early exit on purpose: if the file holds duplicate entries,
            # the last one wins, exactly as in the previous implementation.
            extent = (
                clip['x_meter_stop'] - clip['x_meter_start'],
                clip['y_meter_stop'] - clip['y_meter_start'],
            )
    return extent


def _mot_rows(results, real_width, real_height):
    """Format tracking results as comma-separated text rows, one per fish per frame."""
    rows = []
    for frame in results['frames']:
        for fish in frame['fish']:
            bbox = fish['bbox']
            # bbox is used as (x0, y0, x1, y1) scaled by the image size;
            # presumably normalised to [0, 1] - TODO confirm against the tracker.
            left = bbox[0] * real_width
            top = bbox[1] * real_height
            w = bbox[2] * real_width - bbox[0] * real_width
            h = bbox[3] * real_height - bbox[1] * real_height

            rows.append(",".join([
                str(frame['frame_num'] + 1),  # frame number shifted by one
                str(fish['fish_id'] + 1),     # track id shifted by one
                str(int(left)),
                str(int(top)),
                str(int(w)),
                str(int(h)),
                "-1",
                "-1",
                "-1",
                "-1",
            ]))
    return rows


def infer_seq(in_dir, out_dir, config, seq_name, model, device, metadata_path, verbose):
    """
    Run detection and tracking on one sequence and write ``<out_dir>/<seq_name>.txt``.

    Args:
        in_dir (str): Directory holding the frames of the sequence.
        out_dir (str): Existing directory that receives the result file.
        config: Configuration object forwarded to ``do_full_tracking``.
        seq_name (str): Sequence name; matched against ``clip_name`` in the metadata
            and used as the output file stem.
        model: Detection model forwarded to ``do_detection``.
        device: Device handle forwarded to ``do_detection``.
        metadata_path (str): Path to the JSON metadata file (a list of clip entries).
        verbose (bool): Forwarded to detection and tracking.

    Returns:
        None. If the metadata has no entry for ``seq_name`` a message is printed and
        nothing is written. If detection raises, ``ERROR_<seq_name>.txt`` is written
        instead of the result file.
    """
    extent = _read_meter_extent(metadata_path, seq_name)
    # None (not -1) marks "not found", so a computed extent can never be
    # mistaken for the missing-metadata case.
    if extent is None:
        print("No metadata found for file " + seq_name)
        return
    image_meter_width, image_meter_height = extent

    # create dataloader
    dataloader = create_dataloader_frames_only(in_dir)

    try:
        inference, image_shapes, width, height = do_detection(dataloader, model, device, verbose=verbose)
    except Exception as exc:
        # Best effort: leave a marker file and move on to the next sequence.
        # Only Exception is caught so Ctrl-C / SystemExit still stop the run.
        print("Error in " + seq_name)
        print(repr(exc))
        with open(os.path.join(out_dir, "ERROR_" + seq_name + ".txt"), 'w') as f:
            f.write("ERROR")
        return

    # The first recorded shape is indexed as (height, width).
    real_width = image_shapes[0][0][0][1]
    real_height = image_shapes[0][0][0][0]

    results = do_full_tracking(inference, image_shapes, image_meter_width, image_meter_height, width, height, config=config, gp=None, verbose=verbose)

    mot_text = "\n".join(_mot_rows(results, real_width, real_height))

    with open(os.path.join(out_dir, seq_name + ".txt"), 'w') as f:
        f.write(mot_text)

def argument_parser():
    """
    Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: Parser with the required options ``--frames``,
        ``--location``, ``--metadata`` and ``--output`` and the optional ``--weights``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--frames", required=True, help="Path to frame directory. Required.")
    parser.add_argument("--location", required=True, help="Name of location dir. Required.")
    parser.add_argument("--metadata", required=True, help="Path to metadata directory. Required.")
    parser.add_argument("--output", required=True, help="Path to output directory. Required.")
    # %(default)s keeps the help text in sync with the real default; the
    # previous hard-coded text named a different path (../models/...).
    parser.add_argument("--weights", default='models/v5m_896_300best.pt', help="Path to saved YOLOv5 weights. Default: %(default)s")
    return parser

if __name__ == "__main__":
    # Script entry point: parse the command line, then run inference with
    # the default configuration and verbosity.
    cli_parser = argument_parser()
    args = cli_parser.parse_args()
    main(args)