# fisheye-experimental / InferenceConfig.py
# Last commit: d9c7dce "Bug fixes" (author: oskarastrom), 3.3 kB
from enum import Enum
class TrackerType(Enum):
    """Selectable association-tracker variants used during inference."""

    NONE = 0
    CONF_BOOST = 1
    BYTETRACK = 2

    def toString(val):
        """Return the human-readable label for ``val``.

        Usable both as ``TrackerType.toString(member)`` and as
        ``member.toString()``. Returns ``None`` when ``val`` does not compare
        equal to any of the members above.
        """
        # Kept as a function-local table: a dict/tuple assigned in the Enum
        # class body itself would be turned into an enum member.
        labels = (
            (TrackerType.NONE, "None"),
            (TrackerType.CONF_BOOST, "Confidence Boost"),
            (TrackerType.BYTETRACK, "ByteTrack"),
        )
        for member, label in labels:
            if val == member:
                return label
        return None
# ---------------------------------------------------------------------------
# Configuration options
# ---------------------------------------------------------------------------

# Detector checkpoint to load.
WEIGHTS: str = 'models/v5m_896_300best.pt'

# Will need to be configured based on the available GPU hardware.
BATCH_SIZE: int = 32

# Detection thresholds.
CONF_THRES: float = 0.05  # detection confidence threshold
NMS_IOU: float = 0.25  # IOU threshold for non-maximum suppression

# Tracking / track filtering.
# Time until a missing fish gets a new id (units presumably frames; confirm).
MAX_AGE: int = 20
# Minimum number of frames a specific fish must appear in for it to count.
MIN_HITS: int = 11
# Minimum fish length, in meters.
MIN_LENGTH: float = 0.3
# Maximum fish length, in meters.
# NOTE(review): 0 presumably means "no upper limit"; confirm with consumers.
MAX_LENGTH: float = 0
# IOU threshold for tracking.
IOU_THRES: float = 0.01
# Minimum distance a track has to travel (units not stated here).
MIN_TRAVEL: float = 0

# Association strategy selected by default.
DEFAULT_TRACKER: TrackerType = TrackerType.BYTETRACK
class InferenceConfig:
    """Detection and tracking parameters for one inference run.

    Constructor defaults come from the module-level configuration constants.
    The associative tracker starts as ``DEFAULT_TRACKER`` and can be switched
    with the ``enable_*`` methods, which also set that tracker's parameters.
    """

    def __init__(self,
                 weights=WEIGHTS, conf_thresh=CONF_THRES, nms_iou=NMS_IOU,
                 min_hits=MIN_HITS, max_age=MAX_AGE, min_length=MIN_LENGTH,
                 max_length=MAX_LENGTH, min_travel=MIN_TRAVEL):
        self.weights = weights
        self.conf_thresh = conf_thresh
        self.nms_iou = nms_iou
        self.min_hits = min_hits
        self.max_age = max_age
        self.min_length = min_length
        self.max_length = max_length
        self.min_travel = min_travel

        # Selected tracker plus the parameters of each variant. Only the
        # selected variant's parameters are reported by to_dict().
        self.associative_tracker = DEFAULT_TRACKER
        self.boost_power = 2
        self.boost_decay = 0.1
        self.byte_low_conf = 0.1
        self.byte_high_conf = 0.3

    def enable_sort_track(self):
        """Disable associative tracking (select ``TrackerType.NONE``)."""
        self.associative_tracker = TrackerType.NONE

    def enable_conf_boost(self, power, decay):
        """Select the confidence-boost tracker with the given power and decay."""
        self.associative_tracker = TrackerType.CONF_BOOST
        self.boost_power = power
        self.boost_decay = decay

    def enable_byte_track(self, low, high):
        """Select ByteTrack with the given low/high confidence thresholds."""
        self.associative_tracker = TrackerType.BYTETRACK
        self.byte_low_conf = low
        self.byte_high_conf = high

    def find_model(self, model_list):
        """Return the name of the model whose path equals ``self.weights``.

        Args:
            model_list: mapping of model name -> weights path.

        Returns:
            The first matching model name, or ``None`` if no path matches.
            Prints the search progress to stdout.
        """
        print("weights", self.weights)
        for model_name, model_path in model_list.items():
            print("Path", model_path, "->", model_name)
            if model_path == self.weights:
                return model_name
        print("not found")
        return None

    def to_dict(self):
        """Return the active configuration as a plain dict.

        Always contains the shared detection/tracking parameters. For a
        recognised ``associative_tracker``, a ``'tracker'`` label and that
        tracker's own parameters are appended: ByteTrack reports
        ``byte_low_conf``/``byte_high_conf``, while Confidence Boost and
        None report ``conf_thresh`` (plus ``boost_power``/``boost_decay``
        for Confidence Boost).
        """
        config = {
            'weights': self.weights,
            'nms_iou': self.nms_iou,
            'min_hits': self.min_hits,
            'max_age': self.max_age,
            'min_length': self.min_length,
            # Reported alongside min_length so a serialized config captures
            # both bounds of the fish-length filter.
            'max_length': self.max_length,
            'min_travel': self.min_travel,
        }

        # Add tracker-specific parameters; labels come from TrackerType so
        # they stay in sync with the enum's display names.
        tracker = self.associative_tracker
        if tracker == TrackerType.BYTETRACK:
            config['tracker'] = TrackerType.toString(tracker)
            config['byte_low_conf'] = self.byte_low_conf
            config['byte_high_conf'] = self.byte_high_conf
        elif tracker == TrackerType.CONF_BOOST:
            config['tracker'] = TrackerType.toString(tracker)
            config['conf_thresh'] = self.conf_thresh
            config['boost_power'] = self.boost_power
            config['boost_decay'] = self.boost_decay
        elif tracker == TrackerType.NONE:
            config['tracker'] = TrackerType.toString(tracker)
            config['conf_thresh'] = self.conf_thresh
        return config