File size: 5,441 Bytes
5ab0373
 
 
 
 
 
 
 
 
 
c9d11b2
5ab0373
 
c9d11b2
5ab0373
 
 
 
 
 
 
 
 
 
 
 
 
 
d57b89b
e8f4d7e
 
 
 
 
 
 
 
 
 
d57b89b
5ab0373
752c2e9
d57b89b
 
 
 
 
 
 
 
 
 
e8f4d7e
 
 
d57b89b
 
c9d11b2
 
 
 
8b2b08b
 
 
 
c9d11b2
 
 
 
 
5ab0373
e8f4d7e
5ab0373
 
 
d57b89b
5ab0373
 
d57b89b
5ab0373
 
 
 
d57b89b
5ab0373
 
 
d57b89b
5ab0373
2482ba4
c9d11b2
2482ba4
 
 
 
 
5ab0373
c9d11b2
5ab0373
c9d11b2
5ab0373
 
 
 
 
 
2482ba4
 
 
 
 
 
5ab0373
2482ba4
 
 
 
5ab0373
 
 
 
 
 
 
 
d57b89b
5ab0373
 
 
 
 
 
 
d57b89b
 
5ab0373
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import project_path
import argparse
from datetime import datetime
import torch
import os
from dataloader import create_dataloader_frames_only
from aris import create_manual_marking, create_metadata_dictionary, prep_for_mm
from inference import setup_model, do_suppression, do_detection, do_tracking, json_dump_round_float
from visualizer import generate_video_batches
import json
from tqdm import tqdm


def main(args, config=None, verbose=True):
    """
    Run detection + tracking on every sequence of every location and write
    one tracker text file per sequence (see `infer_seq`).

    For each location in the hard-coded list (currently only "kenai-val"):
        - Sequences are the entries of {args.frames}/{loc}/
        - Metadata is read from {args.metadata}/{loc}.json
        - Results are written to {args.output}/{loc}/tracker/data/
          (the directory is created if it does not exist)

    Args:
        args: namespace with the attributes `frames`, `metadata`, `output`
            and `weights`.
        config (dict, optional): detection/tracking settings. Keys that are
            missing are filled in place with these defaults:
            conf_threshold=0.3, nms_iou=0.3, min_length=0.3, max_age=20,
            iou_threshold=0.01, min_hits=11. When omitted, a fresh dict is
            created for this call.
        verbose (bool): print a "(i/N) name" banner per sequence and forward
            the flag to the inference helpers.

    TODO: Separate into subtasks in different queues; have a GPU-only queue.
    """
    print("In task...")
    print("Cuda available in task?", torch.cuda.is_available())

    # A `{}` default would be shared between calls and the defaults written
    # below would leak from one call into the next, so build a new dict here.
    # A caller-supplied dict is still completed in place, as before.
    if config is None:
        config = {}

    defaults = {
        'conf_threshold': 0.3,
        'nms_iou': 0.3,
        'min_length': 0.3,
        'max_age': 20,
        'iou_threshold': 0.01,
        'min_hits': 11,
    }
    for key, value in defaults.items():
        config.setdefault(key, value)

    print(config)

    dirname = args.frames

    # The model does not depend on the location, so load it once up front
    # instead of once per location.
    model, device = setup_model(args.weights)

    locations = ["kenai-val"]
    for loc in locations:

        in_loc_dir = os.path.join(dirname, loc)
        out_dir = os.path.join(args.output, loc, "tracker", "data")
        metadata_path = os.path.join(args.metadata, loc + ".json")
        os.makedirs(out_dir, exist_ok=True)
        print(in_loc_dir)
        print(out_dir)
        print(metadata_path)

        # os.listdir returns entries in arbitrary order; sort so that runs
        # (and their progress output) are reproducible.
        seq_list = sorted(os.listdir(in_loc_dir))
        total = len(seq_list)

        # run detection + tracking
        with tqdm(total=total, desc="...", ncols=0) as pbar:
            for idx, seq in enumerate(seq_list, start=1):
                pbar.update(1)
                pbar.set_description("Processing " + seq)
                if verbose:
                    print(" ")
                    print("(" + str(idx) + "/" + str(total) + ") " + seq)
                    print(" ")
                in_seq_dir = os.path.join(in_loc_dir, seq)
                infer_seq(in_seq_dir, out_dir, config, seq, model, device, metadata_path, verbose)

def _lookup_meter_extent(metadata_path, seq_name):
    """Return (width, height) from the metadata entry of `seq_name`, or None.

    The extent is `x_meter_stop - x_meter_start` and
    `y_meter_stop - y_meter_start` of the entry whose `clip_name` equals
    `seq_name`. If several entries match, the last one wins.
    """
    with open(metadata_path, 'r') as f:
        clips = json.load(f)

    extent = None
    for clip in clips:
        if clip['clip_name'] == seq_name:
            extent = (
                clip['x_meter_stop'] - clip['x_meter_start'],
                clip['y_meter_stop'] - clip['y_meter_start'],
            )
    return extent


def _format_mot_rows(results, real_width, real_height):
    """Turn tracking results into a list of comma-separated text rows.

    Each row has 10 fields: frame number (1-based), fish id (1-based), left,
    top, width, height (truncated to int), followed by four "-1" placeholders.
    """
    rows = []
    for frame in results['frames']:
        for fish in frame['fish']:
            # NOTE(review): bbox is scaled by the image size below, so it is
            # presumably normalized (x1, y1, x2, y2) - confirm in do_tracking.
            bbox = fish['bbox']
            left = bbox[0]*real_width
            top = bbox[1]*real_height
            w = bbox[2]*real_width - left
            h = bbox[3]*real_height - top

            rows.append(",".join([
                str(frame['frame_num'] + 1),
                str(fish['fish_id'] + 1),
                str(int(left)),
                str(int(top)),
                str(int(w)),
                str(int(h)),
                "-1",
                "-1",
                "-1",
                "-1",
            ]))
    return rows


def infer_seq(in_dir, out_dir, config, seq_name, model, device, metadata_path, verbose):
    """
    Run detection, suppression and tracking on one sequence and write the
    tracks to {out_dir}/{seq_name}.txt.

    Args:
        in_dir (str): directory holding the frames of the sequence.
        out_dir (str): existing directory that receives the output file.
        config (dict): must contain 'conf_threshold', 'nms_iou', 'min_length',
            'max_age', 'iou_threshold' and 'min_hits'.
        seq_name (str): sequence name; used to find the metadata entry and to
            name the output file.
        model, device: passed through to `do_detection`.
        metadata_path (str): JSON file with a list of per-clip entries.
        verbose (bool): forwarded to the inference helpers.

    Returns:
        None. If the sequence has no metadata entry, nothing is written. If
        detection fails, {out_dir}/ERROR_{seq_name}.txt is written instead of
        the result file.
    """
    # None (rather than a numeric sentinel such as -1) marks "no metadata", so
    # a legitimately computed extent can never be mistaken for a missing one.
    extent = _lookup_meter_extent(metadata_path, seq_name)
    if extent is None:
        print("No metadata found for file " + seq_name)
        return
    image_meter_width, image_meter_height = extent

    # create dataloader
    dataloader = create_dataloader_frames_only(in_dir)

    try:
        inference, width, height = do_detection(dataloader, model, device, verbose=verbose)
    except Exception as exc:
        # Best effort: record the failure and move on to the next sequence.
        # `Exception` (not a bare except) lets Ctrl-C / SystemExit propagate.
        print("Error in " + seq_name + ": " + repr(exc))
        with open(os.path.join(out_dir, "ERROR_" + seq_name + ".txt"), 'w') as f:
            f.write("ERROR")
        return

    all_preds, real_width, real_height = do_suppression(
        dataloader, inference, width, height,
        conf_thres=config['conf_threshold'],
        iou_thres=config['nms_iou'],
        verbose=verbose,
    )

    results = do_tracking(
        all_preds, image_meter_width, image_meter_height,
        min_length=config['min_length'],
        max_age=config['max_age'],
        iou_thres=config['iou_threshold'],
        min_hits=config['min_hits'],
        verbose=verbose,
    )

    mot_text = "\n".join(_format_mot_rows(results, real_width, real_height))

    with open(os.path.join(out_dir, seq_name + ".txt"), 'w') as f:
        f.write(mot_text)

    return

def argument_parser():
    """
    Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: parser with the required options --frames,
        --metadata and --output, and the optional --weights.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--frames", required=True, help="Path to frame directory. Required.")
    parser.add_argument("--metadata", required=True, help="Path to metadata directory. Required.")
    parser.add_argument("--output", required=True, help="Path to output directory. Required.")
    # %(default)s keeps the help text in sync with the actual default; the
    # previous hand-written text pointed at a different path (../models/...).
    parser.add_argument("--weights", default='models/v5m_896_300best.pt', help="Path to saved YOLOv5 weights. Default: %(default)s")
    return parser

if __name__ == "__main__":
    # Script entry point: parse the command line, then process all sequences
    # with the default config and verbose output.
    main(argument_parser().parse_args())