import project_path
import torch
import torchvision
from tqdm import tqdm
from functools import partial
import numpy as np
import json
import time
from unittest.mock import patch

# assumes yolov5 on sys.path
from lib.yolov5.models.experimental import attempt_load
from lib.yolov5.utils.torch_utils import select_device
from lib.yolov5.utils.general import clip_boxes, scale_boxes, xywh2xyxy
from lib.yolov5.utils.metrics import box_iou

from lib.fish_eye.tracker import Tracker

### Configuration options
WEIGHTS = 'models/v5m_896_300best.pt'

# will need to configure these based on GPU hardware
BATCH_SIZE = 32

CONF_THRES = 0.3   # detection confidence threshold
NMS_IOU = 0.3      # IOU threshold for NMS
MIN_LENGTH = 0.3   # minimum fish length, in meters
MAX_AGE = 20       # time (in frames) until a missing fish gets a new id
IOU_THRES = 0.01   # IOU threshold for tracking
MIN_HITS = 11      # minimum number of frames a specific fish must appear in for it to count
###

def norm(bbox, w, h):
    """
    Normalize a bounding box by image size.
    Args:
        bbox: list of length 4. Can be [x,y,w,h] or [x0,y0,x1,y1]
        w: image width
        h: image height
    Returns:
        A copy of bbox with x-values divided by w and y-values divided by h.
    """
    bb = bbox.copy()
    bb[0] /= w
    bb[1] /= h
    bb[2] /= w
    bb[3] /= h
    return bb
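# Example (illustrative only): normalizing a pixel-space xyxy box for a
# hypothetical 896x896 frame. The numbers below are made up for demonstration.
#
#   norm([100.0, 200.0, 300.0, 400.0], w=896, h=896)
#   # -> approximately [0.1116, 0.2232, 0.3348, 0.4464]
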
def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None, weights=WEIGHTS):
    model, device = setup_model(weights)

    # Debug toggles: 'load' reads previously saved raw inference output instead of
    # running detection; 'save' writes raw inference output to disk and returns early.
    load = False
    save = False
    if load:
        with open('static/example/inference_output.json', 'r') as f:
            json_object = json.load(f)
            inference = json_object['inference']
            width = json_object['width']
            height = json_object['height']
    else:
        inference, width, height = do_detection(dataloader, model, device, gp=gp)

    if save:
        json_object = {
            'inference': inference,
            'width': width,
            'height': height
        }
        json_text = json.dumps(json_object, indent=4)
        with open('static/example/inference_output.json', 'w') as f:
            f.write(json_text)
        return

    all_preds, real_width, real_height = do_suppression(dataloader, inference, width, height, gp=gp)

    results = do_tracking(all_preds, image_meter_width, image_meter_height, gp=gp)
    return results
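# Sketch of how the pieces above are typically wired together. The dataloader
# construction is an assumption here: this module expects a YOLOv5-style
# dataloader yielding (img, _, shapes) batches, built elsewhere in the project.
# 'create_dataloader_from_frames', the frame directory, and the meter dimensions
# below are hypothetical placeholders, not part of this file.
#
#   dataloader = create_dataloader_from_frames('frames/clip_0001/', imgsz=896,
#                                              batch_size=BATCH_SIZE)
#   results = do_full_inference(dataloader,
#                               image_meter_width=5.0,    # example sonar extent
#                               image_meter_height=10.0,
#                               gp=lambda pct, msg: print(f"{pct:.0%} {msg}"))
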
def setup_model(weights_fp=WEIGHTS, imgsz=896, batch_size=32):
    if torch.cuda.is_available():
        device = select_device('0', batch_size=batch_size)
    else:
        print("CUDA not available. Using CPU inference.")
        device = select_device('cpu', batch_size=batch_size)

    # Set up model for inference
    model = attempt_load(weights_fp, device=device)
    half = device.type != 'cpu'  # half precision only supported on CUDA
    if half:
        model.half()
    model.eval()

    # Warm up the model with a dummy image (GPU only)
    img = torch.zeros((1, 3, imgsz, imgsz), device=device)
    _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once
    return model, device
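# Example (sketch): the model is loaded once and reused for every clip; the
# dataloader passed to do_detection() is assumed to produce images letterboxed
# to the same imgsz (896 by default).
#
#   model, device = setup_model()            # uses WEIGHTS defined above
#   inference, width, height = do_detection(dataloader, model, device)
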
def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE, verbose=True):
    """
    Run the raw YOLOv5 model over every batch in the dataloader.
    Args:
        dataloader: a YOLOv5-style dataloader yielding (img, _, shapes) batches
        model: model returned by setup_model()
        device: torch device returned by setup_model()
        gp: a callback function which takes 2 parameters, (float) fraction complete and (str) progress message
        batch_size: number of images per batch (used only for progress reporting)
        verbose: show a tqdm progress bar
    Returns:
        inference: list of raw model outputs, one entry per batch
        width, height: model input width and height (from the last batch)
    """
    if gp: gp(0, "Detection...")

    inference = []

    # Run detection
    with tqdm(total=len(dataloader)*batch_size, desc="Running detection", ncols=0, disable=not verbose) as pbar:
        for batch_i, (img, _, shapes) in enumerate(dataloader):
            if gp: gp(batch_i / len(dataloader), pbar.__str__())

            img = img.to(device, non_blocking=True)
            img = img.half() if device.type != 'cpu' else img.float()  # uint8 to fp16/32
            img /= 255.0  # 0 - 255 to 0.0 - 1.0

            size = tuple(img.shape)
            nb, _, height, width = size  # batch size, channels, height, width

            # Run model
            with torch.no_grad():
                inf_out, _ = model(img, augment=False)
                inference.append(inf_out)

            pbar.update(1*batch_size)

    return inference, width, height
def do_suppression(dataloader, inference, width, height, gp=None, batch_size=BATCH_SIZE, conf_thres=CONF_THRES, iou_thres=NMS_IOU, verbose=True):
    """
    Run NMS on the raw inference output and convert boxes into tracker input format.
    Args:
        dataloader: the same dataloader used for do_detection()
        inference: list of raw model outputs returned by do_detection()
        width, height: model input width and height returned by do_detection()
        gp: a callback function which takes 2 parameters, (float) fraction complete and (str) progress message
        batch_size: number of images per batch (used only for progress reporting)
        conf_thres: detection confidence threshold
        iou_thres: IOU threshold for NMS
        verbose: show a tqdm progress bar
    Returns:
        all_preds: dict mapping (batch_index, image_index) to an (n, 5) array of
            normalized [x0, y0, x1, y1, conf] boxes, or None for frames with no detections
        real_width, real_height: original (pre-letterbox) image width and height in pixels
    """
    if gp: gp(0, "Suppression...")

    # keep predictions to feed them ordered into the Tracker
    # TODO: how to deal with large files?
    all_preds = {}
    real_width, real_height = None, None

    with tqdm(total=len(dataloader)*batch_size, desc="Running suppression", ncols=0, disable=not verbose) as pbar:
        for batch_i, (img, _, shapes) in enumerate(dataloader):
            if gp: gp(batch_i / len(dataloader), pbar.__str__())

            inf_out = inference[batch_i]
            with torch.no_grad():
                output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres)

            # Format results
            for si, pred in enumerate(output):
                # Original (pre-letterbox) image height and width
                real_height = shapes[si][0][0]
                real_width = shapes[si][0][1]

                # Clip boxes to image bounds and resize to input shape
                clip_boxes(pred, (height, width))
                box = pred[:, :4].clone()  # xyxy
                confs = pred[:, 4].clone().tolist()
                scale_boxes(img[si].shape[1:], box, shapes[si][0], shapes[si][1])  # to original shape

                # get boxes into tracker input format - normalized xyxy with confidence score
                boxes = None
                if box.shape[0]:
                    do_norm = partial(norm, w=real_width, h=real_height)
                    normed = list(map(do_norm, box[:, :4].tolist()))
                    boxes = np.stack([[*bb, conf] for bb, conf in zip(normed, confs)])
                frame_num = (batch_i, si)
                all_preds[frame_num] = boxes

            pbar.update(1*batch_size)

    return all_preds, real_width, real_height
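# Illustrative shape of the all_preds dict built above (the values are invented
# for demonstration): keys are (batch_index, image_index) tuples in frame order,
# values are (n, 5) arrays of normalized [x0, y0, x1, y1, conf] rows, or None
# when a frame has no detections.
#
#   all_preds = {
#       (0, 0): np.array([[0.12, 0.40, 0.18, 0.47, 0.91]]),
#       (0, 1): None,
#       (0, 2): np.array([[0.13, 0.41, 0.19, 0.48, 0.88],
#                         [0.55, 0.10, 0.61, 0.16, 0.42]]),
#   }
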
def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, verbose=True):
    if gp: gp(0, "Tracking...")

    # Initialize tracker
    clip_info = {
        'start_frame': 0,
        'end_frame': len(all_preds),
        'image_meter_width': image_meter_width,
        'image_meter_height': image_meter_height
    }
    tracker = Tracker(clip_info, args={'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)

    # Run tracking
    with tqdm(total=len(all_preds), desc="Running tracking", ncols=0, disable=not verbose) as pbar:
        for i, key in enumerate(sorted(all_preds.keys())):
            if gp: gp(i / len(all_preds), pbar.__str__())

            boxes = all_preds[key]
            if boxes is not None:
                tracker.update(boxes)
            else:
                tracker.update()

            pbar.update(1)

    json_data = tracker.finalize(min_length=min_length)
    return json_data
def json_dump_round_float(some_object, out_path, num_digits=4):
    """Write a json file to disk with a specified level of float precision.
    See: https://gist.github.com/Sukonnik-Illia/ed9b2bec1821cad437d1b8adb17406a3
    """
    # saving original method
    of = json.encoder._make_iterencode
    def inner(*args, **kwargs):
        args = list(args)
        # fifth argument is the float formatter, which we replace
        fmt_str = '{:.' + str(num_digits) + 'f}'
        args[4] = lambda o: fmt_str.format(o)
        return of(*args, **kwargs)
    with patch('json.encoder._make_iterencode', wraps=inner):
        with open(out_path, 'w') as f:
            return json.dump(some_object, f, indent=2)
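# Example (sketch): writing tracking results with 4-digit floats. The output
# path below is a made-up placeholder, not a path this project defines.
#
#   results = do_tracking(all_preds, image_meter_width, image_meter_height)
#   json_dump_round_float(results, 'output/results.json')
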
def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        max_det=300,
):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
    Returns:
        list of detections, one (n, 6) tensor per image [xyxy, conf, cls]
    """
    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation mode, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    redundant = True  # require redundant detections
    merge = False  # use merge-NMS

    output = [torch.zeros((0, 6), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Keep boxes that pass confidence threshold
        x = x[xc[xi]]  # confidence

        # If none remain, process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box/Mask
        box = xywh2xyxy(x[:, :4])  # (center_x, center_y, width, height) to (x1, y1, x2, y2)
        mask = x[:, 6:]  # zero columns if no masks

        # Detections matrix nx6 (xyxy, conf, cls)
        conf, j = x[:, 5:6].max(1, keepdim=True)
        x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        boxes = x[:, :4]  # boxes
        scores = x[:, 4]  # scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if mps:
            output[xi] = output[xi].to(device)

    return output
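# Example (sketch) of what non_max_suppression() expects and returns, assuming a
# single-class model with no mask outputs. The tensor below is random and only
# meant to illustrate shapes: (batch, num_candidates, 5 + num_classes).
#
#   preds = torch.rand(1, 25200, 6)   # xywh, obj_conf, one class score per row
#   dets = non_max_suppression(preds, conf_thres=CONF_THRES, iou_thres=NMS_IOU)
#   # dets is a list with one (n, 6) tensor per image: [x1, y1, x2, y2, conf, cls]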