oskarastrom committed on
Commit
2482ba4
1 Parent(s): d57b89b

Updated inference

Browse files
Files changed (2) hide show
  1. inference.py +6 -5
  2. scripts/infer_frames.py +19 -9
inference.py CHANGED
@@ -24,8 +24,8 @@ WEIGHTS = 'models/v5m_896_300best.pt'
24
  # will need to configure these based on GPU hardware
25
  BATCH_SIZE = 32
26
 
27
- conf_thres = 0.001 # detection
28
- iou_thres = 0.1 # NMS IOU
29
  min_length = 0.3 # minimum fish length, in meters
30
  ###
31
 
@@ -72,7 +72,7 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
72
  return
73
 
74
 
75
- all_preds = do_suppression(dataloader, inference, width, height, gp=gp)
76
 
77
  results = do_tracking(all_preds, image_meter_width, image_meter_height, gp=gp)
78
 
@@ -122,7 +122,6 @@ def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE):
122
  size = tuple(img.shape)
123
  nb, _, height, width = size # batch size, channels, height, width
124
 
125
- print(nb, _, height, width)
126
  # Run model & NMS
127
  with torch.no_grad():
128
  inf_out, _ = model(img, augment=False)
@@ -166,6 +165,8 @@ def do_suppression(dataloader, inference, width, height, gp=None, batch_size=BAT
166
  # confidence score currently not used by tracker; set to 1.0
167
  boxes = None
168
  if box.shape[0]:
 
 
169
  do_norm = partial(norm, w=shapes[si][0][1], h=shapes[si][0][0])
170
  normed = list((map(do_norm, box[:, :4].tolist())))
171
  boxes = np.stack([ [*bb, conf] for bb, conf in zip(normed, confs) ])
@@ -174,7 +175,7 @@ def do_suppression(dataloader, inference, width, height, gp=None, batch_size=BAT
174
 
175
  pbar.update(1*batch_size)
176
 
177
- return all_preds
178
 
179
  def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None):
180
 
 
24
  # will need to configure these based on GPU hardware
25
  BATCH_SIZE = 32
26
 
27
+ conf_thres = 0.3 # detection
28
+ iou_thres = 0.3 # NMS IOU
29
  min_length = 0.3 # minimum fish length, in meters
30
  ###
31
 
 
72
  return
73
 
74
 
75
+ all_preds, real_width, real_height = do_suppression(dataloader, inference, width, height, gp=gp)
76
 
77
  results = do_tracking(all_preds, image_meter_width, image_meter_height, gp=gp)
78
 
 
122
  size = tuple(img.shape)
123
  nb, _, height, width = size # batch size, channels, height, width
124
 
 
125
  # Run model & NMS
126
  with torch.no_grad():
127
  inf_out, _ = model(img, augment=False)
 
165
  # confidence score currently not used by tracker; set to 1.0
166
  boxes = None
167
  if box.shape[0]:
168
+ real_width = shapes[si][0][1]
169
+ real_height = shapes[si][0][0]
170
  do_norm = partial(norm, w=shapes[si][0][1], h=shapes[si][0][0])
171
  normed = list((map(do_norm, box[:, :4].tolist())))
172
  boxes = np.stack([ [*bb, conf] for bb, conf in zip(normed, confs) ])
 
175
 
176
  pbar.update(1*batch_size)
177
 
178
+ return all_preds, real_width, real_height
179
 
180
  def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None):
181
 
scripts/infer_frames.py CHANGED
@@ -8,7 +8,6 @@ from aris import create_manual_marking, create_metadata_dictionary, prep_for_mm
8
  from inference import setup_model, do_suppression, do_detection, do_tracking, json_dump_round_float
9
  from visualizer import generate_video_batches
10
  import json
11
- from tracker import Tracker
12
 
13
 
14
  def main(args):
@@ -29,7 +28,7 @@ def main(args):
29
 
30
  dirname = args.frames
31
 
32
- locations = ["kenai-val"]
33
  for loc in locations:
34
 
35
  in_loc_dir = os.path.join(dirname, loc)
@@ -73,9 +72,15 @@ def infer_seq(in_dir, out_dir, seq_name, weights, metadata_path):
73
  # run detection + tracking
74
  model, device = setup_model(weights)
75
 
76
- inference, width, height = do_detection(dataloader, model, device, gp=gradio_progress)
 
 
 
 
 
 
77
 
78
- all_preds = do_suppression(dataloader, inference, width, height, gp=gradio_progress)
79
 
80
  results = do_tracking(all_preds, image_meter_width, image_meter_height, gp=gradio_progress)
81
 
@@ -84,12 +89,17 @@ def infer_seq(in_dir, out_dir, seq_name, weights, metadata_path):
84
  for fish in frame['fish']:
85
  bbox = fish['bbox']
86
  row = []
87
- row.append(str(frame['frame_num']))
 
 
 
 
 
88
  row.append(str(fish['fish_id'] + 1))
89
- row.append(str(int(bbox[0]*width)))
90
- row.append(str(int(bbox[1]*height)))
91
- row.append(str(int(bbox[2]*width)))
92
- row.append(str(int(bbox[3]*height)))
93
  row.append("-1")
94
  row.append("-1")
95
  row.append("-1")
 
8
  from inference import setup_model, do_suppression, do_detection, do_tracking, json_dump_round_float
9
  from visualizer import generate_video_batches
10
  import json
 
11
 
12
 
13
  def main(args):
 
28
 
29
  dirname = args.frames
30
 
31
+ locations = ["test"]
32
  for loc in locations:
33
 
34
  in_loc_dir = os.path.join(dirname, loc)
 
72
  # run detection + tracking
73
  model, device = setup_model(weights)
74
 
75
+ try:
76
+ inference, width, height = do_detection(dataloader, model, device, gp=gradio_progress)
77
+ except:
78
+ print("Error in " + seq_name)
79
+ with open(os.path.join(out_dir, "ERROR_" + seq_name + ".txt"), 'w') as f:
80
+ f.write("ERROR")
81
+ return
82
 
83
+ all_preds, real_width, real_height = do_suppression(dataloader, inference, width, height, gp=gradio_progress)
84
 
85
  results = do_tracking(all_preds, image_meter_width, image_meter_height, gp=gradio_progress)
86
 
 
89
  for fish in frame['fish']:
90
  bbox = fish['bbox']
91
  row = []
92
+ right = bbox[0]*real_width
93
+ top = bbox[1]*real_height
94
+ w = bbox[2]*real_width - bbox[0]*real_width
95
+ h = bbox[3]*real_height - bbox[1]*real_height
96
+
97
+ row.append(str(frame['frame_num'] + 1))
98
  row.append(str(fish['fish_id'] + 1))
99
+ row.append(str(int(right)))
100
+ row.append(str(int(top)))
101
+ row.append(str(int(w)))
102
+ row.append(str(int(h)))
103
  row.append("-1")
104
  row.append("-1")
105
  row.append("-1")