# fisheye-experimental / inference.py
# Web file-view header (kept for provenance): author oskarastrom, commit d9c7dce
# "Bug fixes", size 20.2 kB; page links "raw", "history", "blame"; scan result "No virus".
import project_path
import torch
from tqdm import tqdm
from functools import partial
import numpy as np
import json
import time
from unittest.mock import patch
import math
# assumes yolov5 on sys.path
from lib.yolov5.models.experimental import attempt_load
from lib.yolov5.utils.torch_utils import select_device
from lib.yolov5.utils.general import clip_boxes, scale_boxes, xywh2xyxy
from lib.yolov5.utils.metrics import box_iou
import torch
import torchvision
from InferenceConfig import InferenceConfig, TrackerType
from lib.fish_eye.tracker import Tracker
from lib.fish_eye.bytetrack import Associate
### Configuration options
# Default detector checkpoint loaded by setup_model().
WEIGHTS = 'models/v5m_896_300best.pt'
# Detection batch size; tune this to the available GPU hardware.
BATCH_SIZE = 32
# Minimum detection confidence.
CONF_THRES = 0.05
# IoU threshold used by non-max suppression.
NMS_IOU = 0.2
# Frames a fish may go missing before it gets a new id.
MAX_AGE = 14
# Minimum number of frames a fish must be seen in for its track to count.
MIN_HITS = 16
# Minimum fish length, in meters.
MIN_LENGTH = 0.3
# IoU threshold used when associating detections with tracks.
IOU_THRES = 0.01
# Minimum distance a track has to travel (-1 presumably disables the check; confirm in Tracker.finalize).
MIN_TRAVEL = -1
###
def norm(bbox, w, h):
    """Return a copy of ``bbox`` with its first four coordinates divided by the image size.

    Args:
        bbox: sequence supporting ``.copy()`` and item assignment (e.g. a list) holding
            either [x, y, w, h] or [x0, y0, x1, y1]; any elements past index 3 are
            copied unchanged. The input itself is not modified.
        w: image width; divides indices 0 and 2.
        h: image height; divides indices 1 and 3.
    """
    scaled = bbox.copy()
    for axis, divisor in enumerate((w, h, w, h)):
        scaled[axis] /= divisor
    return scaled
def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None, config=None):
    """Run detection, suppression and tracking over every batch produced by ``dataloader``.

    The tracker is chosen by ``config.associative_tracker``:

    * ``TrackerType.BYTETRACK``: detections are suppressed at both
      ``config.byte_low_conf`` and ``config.byte_high_conf`` and tracked with
      ``do_associative_tracking``.
    * otherwise: detections are suppressed at ``config.conf_thresh``. For
      ``TrackerType.CONF_BOOST``, raw confidences are first boosted around those
      detections and suppression is re-run. The result is tracked with ``do_tracking``.

    Args:
        dataloader: batches consumed by ``do_detection``.
        image_meter_width: frame width in meters (size filter and tracking).
        image_meter_height: frame height in meters (tracking).
        gp: optional progress callback, called as ``gp(fraction, message)``.
        config: an ``InferenceConfig``. When omitted, a fresh default instance is
            created for this call.

    Returns:
        Whatever the selected tracking function returns (its ``tracker.finalize`` result).
    """
    # Build the default per call rather than once at import time, so no
    # InferenceConfig instance is silently shared between calls.
    if config is None:
        config = InferenceConfig()

    # Set up model
    model, device = setup_model(config.weights)

    # Detect boxes in frames
    inference, image_shapes, width, height = do_detection(dataloader, model, device, gp=gp)

    def suppress(conf_thres):
        # Every suppression pass shares NMS IoU and size limit; only the confidence threshold varies.
        return do_suppression(inference, image_meter_width, width, conf_thres=conf_thres,
                              iou_thres=config.nms_iou, max_length=config.max_length, gp=gp)

    if config.associative_tracker == TrackerType.BYTETRACK:
        # Low- and high-confidence detections for associative tracking (ByteTrack)
        low_preds, _, _ = format_predictions(image_shapes, suppress(config.byte_low_conf), width, height, gp=gp)
        high_preds, _, _ = format_predictions(image_shapes, suppress(config.byte_high_conf), width, height, gp=gp)
        return do_associative_tracking(
            low_preds, high_preds, image_meter_width, image_meter_height,
            reverse=False, min_length=config.min_length, min_travel=config.min_travel,
            max_age=config.max_age, min_hits=config.min_hits,
            gp=gp)

    # Find confident detections
    outputs = suppress(config.conf_thresh)
    if config.associative_tracker == TrackerType.CONF_BOOST:
        # Boost raw confidences (in place) around the confident detections, then re-suppress.
        do_confidence_boost(inference, outputs, boost_power=config.boost_power, boost_decay=config.boost_decay, gp=gp)
        outputs = suppress(config.conf_thresh)

    all_preds, _, _ = format_predictions(image_shapes, outputs, width, height, gp=gp)

    # Perform SORT tracking
    return do_tracking(
        all_preds, image_meter_width, image_meter_height,
        min_length=config.min_length, min_travel=config.min_travel,
        max_age=config.max_age, min_hits=config.min_hits,
        gp=gp)
def setup_model(weights_fp=WEIGHTS, imgsz=896, batch_size=32):
    """Load the detector and prepare it for inference.

    CUDA device '0' is used when ``torch.cuda.is_available()``; otherwise the CPU is
    used. On a non-CPU device the model is converted to fp16 and warmed up with one
    forward pass on a zero image of shape (1, 3, imgsz, imgsz).

    Args:
        weights_fp: weights file passed to ``attempt_load``.
        imgsz: side length of the square warm-up image.
        batch_size: forwarded to ``select_device``.

    Returns:
        (model, device): the model in eval mode and the selected torch device.
    """
    on_gpu = torch.cuda.is_available()
    if not on_gpu:
        print("CUDA not available. Using CPU inference.")
    device = select_device('0' if on_gpu else 'cpu', batch_size=batch_size)

    model = attempt_load(weights_fp, device=device)
    # Half precision is only used off the CPU.
    use_half = device.type != 'cpu'
    if use_half:
        model.half()
    model.eval()

    # Warm-up forward pass (skipped on CPU).
    if use_half:
        dummy = torch.zeros((1, 3, imgsz, imgsz), device=device)
        model(dummy.half())
    return model, device
def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE, verbose=True):
    """Run the detection model over every batch of ``dataloader``.

    Args:
        dataloader: sized iterable yielding ``(img, _, shapes)`` batches. ``img`` is a
            batched image tensor with 0-255 pixel values; ``shapes[i]`` is the per-image
            shape record later passed to ``scale_boxes`` (assumed yolov5 loader format;
            confirm).
        model: detection model, e.g. from ``setup_model``.
        device: torch device the model lives on. Images are cast to fp16 off the CPU
            and to fp32 on the CPU.
        gp: optional progress callback, called as ``gp(fraction, message)``.
        batch_size: nominal batch size, used only for the progress-bar total.
        verbose: show the tqdm progress bar.

    Returns:
        (inference, image_shapes, width, height):
        * inference: list of raw model outputs, one per batch.
        * image_shapes: per-batch lists of ``(model_input_hw, shapes[i])`` pairs.
        * width, height: model-input size of the last batch; both are ``None`` if the
          dataloader yielded no batches.
    """
    if gp:
        gp(0, "Detection...")
    inference = []
    image_shapes = []
    # Pre-bound so an empty dataloader returns instead of raising UnboundLocalError.
    width = height = None

    with tqdm(total=len(dataloader)*batch_size, desc="Running detection", ncols=0, disable=not verbose) as pbar:
        for batch_i, (img, _, shapes) in enumerate(dataloader):
            if gp:
                gp(batch_i / len(dataloader), pbar.__str__())
            img = img.to(device, non_blocking=True)
            img = img.half() if device.type != 'cpu' else img.float()  # uint8 to fp16/32
            img /= 255.0  # 0 - 255 to 0.0 - 1.0
            _, _, height, width = img.shape  # batch size, channels, height, width

            # Run model (NMS happens later in do_suppression)
            with torch.no_grad():
                inf_out, _ = model(img, augment=False)

            # Save shapes for resizing boxes back to the original image shape
            image_shapes.append([(img[si].shape[1:], shapes[si]) for si in range(len(inf_out))])
            inference.append(inf_out)
            pbar.update(1*batch_size)

    return inference, image_shapes, width, height
def do_suppression(inference, image_meter_width, image_pixel_width, gp=None, batch_size=BATCH_SIZE, conf_thres=CONF_THRES, iou_thres=NMS_IOU, max_length=1.5, verbose=True):
    """Apply confidence, size and non-max suppression to each batch of raw detections.

    Args:
        inference: list of raw model outputs, one per batch (as from ``do_detection``).
        image_meter_width: image width in meters (size filter).
        image_pixel_width: image width in pixels (size filter).
        gp: optional progress callback, called as ``gp(fraction, message)``.
        batch_size: nominal batch size, used only for the progress bar.
        conf_thres, iou_thres, max_length: forwarded to ``non_max_suppression``.
        verbose: show the tqdm progress bar.

    Returns:
        One entry per batch: the per-image list returned by ``non_max_suppression``.
    """
    if gp:
        gp(0, "Suppression...")
    # TODO: how to deal with large files? Every batch's result is kept in memory.
    suppressed = []
    batch_count = len(inference)
    with tqdm(total=batch_count*batch_size, desc="Running suppression", ncols=0, disable=not verbose) as pbar:
        for batch_i, raw in enumerate(inference):
            if gp:
                gp(batch_i / batch_count, str(pbar))
            with torch.no_grad():
                kept = non_max_suppression(raw, image_meter_width, image_pixel_width,
                                           conf_thres=conf_thres, iou_thres=iou_thres,
                                           max_length=max_length)
            suppressed.append(kept)
            pbar.update(batch_size)
    return suppressed
def format_predictions(image_shapes, outputs, width, height, gp=None, batch_size=BATCH_SIZE, verbose=True):
    """Convert suppressed detections into normalised per-frame tracker input.

    Args:
        image_shapes: per-batch lists of ``(model_input_hw, original_shape)`` pairs
            (as from ``do_detection``).
        outputs: per-batch lists of per-image detection tensors with rows
            [x1, y1, x2, y2, conf, ...] (as from ``do_suppression``).
        width, height: model-input size; boxes are clipped to it before rescaling.
        gp: optional progress callback, called as ``gp(fraction, message)``.
        batch_size: nominal batch size, used only for the progress bar.
        verbose: show the tqdm progress bar.

    Returns:
        (all_preds, real_width, real_height):
        * all_preds: maps ``(batch_i, image_i)`` to an ndarray of rows
          [x0, y0, x1, y1, conf], normalised by the original image size, or to
          ``None`` for frames without detections.
        * real_width, real_height: original image size of the last processed frame,
          or ``None`` if there were no frames.
    """
    if gp:
        gp(0, "Formatting...")
    # TODO: how to deal with large files? All frames are kept in memory.
    all_preds = {}
    # Bound up front (and refreshed for every frame, not only frames with boxes) so
    # a clip without any detections no longer raises UnboundLocalError at return.
    real_width = real_height = None
    with tqdm(total=len(image_shapes)*batch_size, desc="Running formatting", ncols=0, disable=not verbose) as pbar:
        for batch_i, batch in enumerate(outputs):
            if gp:
                gp(batch_i / len(image_shapes), pbar.__str__())
            batch_shapes = image_shapes[batch_i]
            for si, pred in enumerate(batch):
                image_shape, original_shape = batch_shapes[si]
                # original_shape[0] is read as (height, width) of the source image
                # (presumably the yolov5 loader's (h0, w0); confirm).
                real_height = original_shape[0][0]
                real_width = original_shape[0][1]

                # Clip boxes to model-input bounds, then rescale to the original shape
                clip_boxes(pred, (height, width))
                box = pred[:, :4].clone()  # xyxy
                confs = pred[:, 4].clone().tolist()
                scale_boxes(image_shape, box, original_shape[0], original_shape[1])

                # Tracker input: normalised xyxy followed by the detection confidence
                boxes = None
                if box.shape[0]:
                    normed = [norm(bb, w=real_width, h=real_height) for bb in box[:, :4].tolist()]
                    boxes = np.stack([[*bb, conf] for bb, conf in zip(normed, confs)])
                all_preds[(batch_i, si)] = boxes
            pbar.update(1*batch_size)
    return all_preds, real_width, real_height
# ---------------------------------------- TRACKING ------------------------------------------
def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, min_travel=MIN_TRAVEL, verbose=True):
    """Perform SORT tracking over formatted detections, visiting frames in sorted key order.

    Args:
        all_preds: dict mapping frame key to an array of detections or ``None``
            (as from ``format_predictions``).
        image_meter_width, image_meter_height: frame size in meters, passed to the Tracker.
        gp: optional progress callback, called as ``gp(fraction, message)``.
        max_age, iou_thres: forwarded to the tracking algorithm via ``args``.
        min_hits: forwarded to the Tracker; the per-algorithm ``min_hits`` is fixed at 0.
        min_length, min_travel: forwarded to ``tracker.finalize``.
        verbose: show the tqdm progress bar.

    Returns:
        The result of ``tracker.finalize``.
    """
    if gp:
        gp(0, "Tracking...")
    frame_count = len(all_preds)
    clip_info = {
        'start_frame': 0,
        'end_frame': frame_count,
        'image_meter_width': image_meter_width,
        'image_meter_height': image_meter_height
    }
    sort_args = {'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}
    tracker = Tracker(clip_info, args=sort_args, min_hits=min_hits)

    with tqdm(total=frame_count, desc="Running tracking", ncols=0, disable=not verbose) as pbar:
        for step, frame_key in enumerate(sorted(all_preds)):
            if gp:
                gp(step / frame_count, str(pbar))
            frame_boxes = all_preds[frame_key]
            # Frames without detections still advance the tracker.
            if frame_boxes is None:
                tracker.update()
            else:
                tracker.update(frame_boxes)
            pbar.update(1)

    return tracker.finalize(min_length=min_length, min_travel=min_travel)
def do_confidence_boost(inference, safe_preds, gp=None, batch_size=BATCH_SIZE, boost_power=1, boost_decay=1, verbose=True):
    """Boost raw-detection confidences near confident detections in neighbouring frames.

    For every frame with at least one entry in ``safe_preds``, each frame within
    ``boost_range`` frames of it (excluding itself, and reaching into the adjacent
    previous or next batch when needed) is passed to ``boost_frame``. That call
    modifies the raw detections in ``inference`` in place. Nothing is returned.

    Args:
        inference: list of per-batch raw model outputs (as from ``do_detection``);
            indexing a batch by frame yields that frame's raw detections.
        safe_preds: per-batch lists of per-frame suppressed detections
            (as from ``do_suppression``).
        gp: optional progress callback, called as ``gp(fraction, message)``.
        batch_size: nominal batch size, used only for the progress bar.
        boost_power: peak boost factor.
        boost_decay: decay rate of the boost with squared frame distance.
        verbose: show the tqdm progress bar.
    """
    if gp:
        gp(0, "Confidence Boost...")
    boost_cutoff = 0.01
    # boost_frame scales by boost_scale * exp(-decay * (dt^2 - 1)), i.e.
    # boost_power * exp(-decay * dt^2). boost_range is the largest |dt| whose factor
    # stays >= boost_cutoff. A peak at or below the cutoff reaches no neighbour; it
    # would also make log() non-positive and sqrt() fail.
    if boost_power > boost_cutoff:
        boost_range = math.floor(math.sqrt(1/boost_decay * math.log(boost_power / boost_cutoff)))
    else:
        boost_range = 0
    boost_scale = boost_power * math.exp(-boost_decay)

    n_batches = len(inference)
    with tqdm(total=n_batches*batch_size, desc="Running confidence boost", ncols=0, disable=not verbose) as pbar:
        for batch_i in range(n_batches):
            if gp:
                gp(batch_i / n_batches, pbar.__str__())
            infer = inference[batch_i]
            prev_batch = inference[batch_i - 1] if batch_i > 0 else None
            next_batch = inference[batch_i + 1] if batch_i + 1 < n_batches else None

            for i, safe_frame in enumerate(safe_preds[batch_i]):
                if len(safe_frame) == 0:
                    continue
                for dt in range(-boost_range, boost_range + 1):
                    if dt == 0:
                        continue
                    idx = i + dt
                    if 0 <= idx < len(infer):
                        temp_frame = infer[idx]
                    elif idx < 0 and prev_batch is not None and -idx <= len(prev_batch):
                        # A negative index counts back from the end of the previous batch.
                        # (This test was inverted, so the previous batch was never used.)
                        temp_frame = prev_batch[idx]
                    elif idx >= len(infer) and next_batch is not None and idx - len(infer) < len(next_batch):
                        temp_frame = next_batch[idx - len(infer)]
                    else:
                        continue
                    boost_frame(safe_frame, temp_frame, dt, power=boost_scale, decay=boost_decay)
            pbar.update(1*batch_size)
def boost_frame(safe_frame, base_frame, dt, power=1, decay=1):
    """Boost confidences of ``base_frame`` in place using confident detections from another frame.

    Each raw detection's confidence (column 4) is multiplied by
    ``1 + power * score * exp(-decay * (dt*dt - 1))``. Here ``score`` is the sum over
    ``safe_frame`` rows of IoU(raw box, safe box) * safe confidence.

    Args:
        safe_frame: suppressed detections with rows [x1, y1, x2, y2, conf, ...].
        base_frame: raw detections with rows [cx, cy, w, h, conf, ...]; modified in place.
        dt: frame offset between the two frames.
        power: boost scale.
        decay: boost decay rate.

    Returns:
        ``base_frame`` itself, after modification.
    """
    safe_boxes = safe_frame[:, :4]
    boxes = xywh2xyxy(base_frame[:, :4])  # (center_x, center_y, width, height) to (x1, y1, x2, y2)

    # On CPU, box_iou's .prod() needs the boxes in double precision. Decide from the
    # tensors' device rather than global CUDA availability, which can disagree.
    if boxes.device.type == 'cpu':
        ious = box_iou(boxes.double(), safe_boxes).float()
    else:
        ious = box_iou(boxes, safe_boxes)

    # score = sum over safe boxes of iou(base_box, safe_box) * confidence(safe_box)
    score = torch.matmul(ious, safe_frame[:, 4])
    base_frame[:, 4] *= 1 + power*score*math.exp(-decay*(dt*dt - 1))
    return base_frame
# ByteTrack
def do_associative_tracking(low_preds, high_preds, image_meter_width, image_meter_height, reverse=False, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, min_travel=MIN_TRAVEL, verbose=True):
    """Run associative (ByteTrack-style) tracking over low- and high-confidence detections.

    Args:
        low_preds: dict mapping frame key to an ndarray of rows [x0, y0, x1, y1, conf]
            or ``None``, built at the low confidence threshold (as from ``format_predictions``).
        high_preds: the same structure at the high confidence threshold. It must contain
            every key of ``low_preds``.
        image_meter_width, image_meter_height: frame size in meters, passed to the Tracker.
        reverse: visit frames in descending key order; also forwarded to the Tracker.
        gp: optional progress callback, called as ``gp(fraction, message)``.
        max_age, iou_thres: forwarded to the Associate algorithm via ``args``.
        min_hits: forwarded to the Tracker; the per-algorithm ``min_hits`` is fixed at 0.
        min_length, min_travel: forwarded to ``tracker.finalize``.
        verbose: show the tqdm progress bar.

    Returns:
        The result of ``tracker.finalize``.
    """
    if gp:
        gp(0, "Tracking...")
    clip_info = {
        'start_frame': 0,
        'end_frame': len(low_preds),
        'image_meter_width': image_meter_width,
        'image_meter_height': image_meter_height
    }
    print("Tracking using Associate")
    tracker = Tracker(clip_info, algorithm=Associate, reverse=reverse, args={'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)

    def as_array(boxes):
        # Frames without detections are stored as None; hand the tracker an empty (0, 5) array.
        return np.empty((0, 5)) if boxes is None else boxes

    with tqdm(total=len(low_preds), desc="Running tracking", ncols=0, disable=not verbose) as pbar:
        for i, key in enumerate(sorted(low_preds.keys(), reverse=reverse)):
            if gp:
                gp(i / len(low_preds), pbar.__str__())
            # Each side is substituted independently. Previously, a frame with low-
            # but no high-confidence detections (or vice versa) was passed as fully
            # empty, discarding the detections it did have.
            tracker.update((as_array(low_preds[key]), as_array(high_preds[key])))
            pbar.update(1)

    return tracker.finalize(min_length=min_length, min_travel=min_travel)
@patch('json.encoder.c_make_encoder', None)
def json_dump_round_float(some_object, out_path, num_digits=4):
    """Write a json file to disk with a specified level of precision.

    With the C encoder disabled (via the ``patch`` decorator), ``json`` builds its
    encoder with the pure-Python ``_make_iterencode``. Its float-formatting callback is
    replaced by a fixed-precision ``'{:.<num_digits>f}'`` formatter. Non-finite floats
    (NaN/inf) fall back to json's own formatter, because ``'{:.4f}'`` would write
    ``nan``/``inf``, which ``json.load`` cannot parse.

    See: https://gist.github.com/Sukonnik-Illia/ed9b2bec1821cad437d1b8adb17406a3

    Args:
        some_object: JSON-serialisable object.
        out_path: destination file path (overwritten).
        num_digits: digits after the decimal point for finite floats.

    Returns:
        None (the return value of ``json.dump``).
    """
    fmt_str = '{:.' + str(num_digits) + 'f}'
    # saving original method
    original_make_iterencode = json.encoder._make_iterencode

    def inner(*args, **kwargs):
        args = list(args)
        # The fifth positional argument of _make_iterencode is the float formatter.
        default_floatstr = args[4]

        def floatstr(o):
            if math.isfinite(o):
                return fmt_str.format(o)
            return default_floatstr(o)

        args[4] = floatstr
        return original_make_iterencode(*args, **kwargs)

    # The context manager closes the file handle deterministically (it was leaked before).
    with patch('json.encoder._make_iterencode', wraps=inner), open(out_path, 'w') as f:
        return json.dump(some_object, f, indent=2)
def filter_detection_size(inference, image_meter_width, width, max_length):
    """Drop raw detections whose physical width is not below ``max_length``.

    Args:
        inference: list of per-batch raw detection tensors indexed as
            (image, candidate, column), with the box width in pixels in column 2.
        image_meter_width: image width in meters.
        width: image width in pixels; converts pixel widths to meters.
        max_length: maximum width in meters. A non-positive value instead keeps
            detections wider than ``max_length``, matching ``non_max_suppression``.

    Returns:
        One entry per batch; each entry is a list with one tensor per image holding
        that image's kept rows (columns unchanged).
    """
    # Computed once, outside the loop. The previous version rebound ``width`` inside
    # the loop, corrupting the scale for every batch after the first.
    pix2width = image_meter_width / width
    outputs = []
    for batch in inference:
        det_width = batch[..., 2] * pix2width
        if max_length > 0:
            keep = det_width < max_length
        else:
            keep = det_width > max_length
        # Per-image lists: kept-row counts differ, so results cannot share one tensor.
        outputs.append([x[keep[xi]] for xi, x in enumerate(batch)])
    return outputs
def non_max_suppression(
    prediction,
    image_meter_width,
    image_pixel_width,
    max_length=1.5,
    conf_thres=0.25,
    iou_thres=0.45,
    max_det=300
):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections.

    NOTE: SIMPLIFIED FOR SINGLE CLASS DETECTION (adapted from yolov5).

    A candidate must pass three tests:
    * objectness (column 4) above ``conf_thres``;
    * physical width (column 2 * image_meter_width / image_pixel_width) below
      ``max_length`` meters (a non-positive ``max_length`` keeps widths above it,
      effectively disabling the cap);
    * obj_conf * cls_conf above ``conf_thres``.
    Survivors are sorted by confidence and reduced with ``torchvision.ops.nms``.

    Args:
        prediction: raw model output indexed as (image, candidate, column), with
            columns [cx, cy, w, h, obj_conf, cls_conf, ...]. If a list/tuple is given,
            its first element is used.
        image_meter_width: image width in meters.
        image_pixel_width: image width in pixels.
        max_length: maximum detection width in meters.
        conf_thres: confidence threshold in [0, 1].
        iou_thres: NMS IoU threshold in [0, 1].
        max_det: maximum detections kept per image.

    Returns:
        list of detections, one (n, 6) tensor per image [xyxy, conf, cls], all on
        ``prediction``'s original device.

    Raises:
        AssertionError: if ``conf_thres`` or ``iou_thres`` is outside [0, 1].
    """
    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation model, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    xc = prediction[..., 4] > conf_thres  # objectness candidates

    # Width filter: box width (column 2) converted from pixels to meters
    pix2width = image_meter_width / image_pixel_width
    width = prediction[..., 2] * pix2width
    if max_length > 0:
        wc = width < max_length
    else:
        # Non-positive max_length: the size cap is effectively ignored
        wc = width > max_length

    # Settings
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    redundant = True  # require redundant detections (merge-NMS only)
    merge = False  # use merge-NMS

    # Empty results live on the caller's device (not the CPU copy used for MPS), so
    # every returned tensor is on the same device.
    output = [torch.zeros((0, 6), device=device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Keep boxes that pass the confidence and width filters. Boolean indexing
        # copies, so the in-place update below does not touch ``prediction``.
        x = x[xc[xi] & wc[xi]]

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box/Mask
        box = xywh2xyxy(x[:, :4])  # (center_x, center_y, width, height) to (x1, y1, x2, y2)
        mask = x[:, 6:]  # zero columns if no masks

        # Detections matrix nx6 (xyxy, conf, cls)
        conf, j = x[:, 5:6].max(1, keepdim=True)
        x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # NMS (single class, so no per-class box offset)
        boxes = x[:, :4]
        scores = x[:, 4]
        i = torchvision.ops.nms(boxes, scores, iou_thres)
        i = i[:max_det]  # limit detections
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if mps:
            output[xi] = output[xi].to(device)

    return output