# aero-recognize / inference.py
import tensorflow as tf
import numpy as np
from imgviz import instances2rgb, label2rgb
from configuration import Config
# detections: (classes: list of class_name, boxes: list of [x1, y1, x2, y2])
# actions: list of f'{action_name}: {confidence}'
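# Illustrative values (hypothetical, not taken from the real models):
#   detections = (['Airplane'], [[120.0, 56.0, 480.0, 310.0]])   # one box in xyxy order
#   actions = ['Takeoff: 0.87']
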
def format_frame(frame, config: Config):
    frame = tf.image.convert_image_dtype(frame, tf.float32)
    frame = tf.image.resize_with_pad(frame, *config.frame_size)
    return frame

def detect_object(detector, frame):
    # Run the detector on a single frame, keeping only class 4
    # (assuming an Ultralytics-style YOLO model trained on COCO,
    # where class 4 is 'airplane').
    result = detector(frame, classes=4, verbose=False)[0]
    classes = result.boxes.cls.numpy()
    boxes = result.boxes.xyxy.numpy()
    detections = (
        [result.names[int(i)].capitalize() for i in classes],
        boxes)
    return detections

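# Usage sketch for detect_object (assumes an Ultralytics YOLO detector;
# the weights file name below is illustrative only):
#   from ultralytics import YOLO
#   detector = YOLO('yolov8n.pt')
#   classes, boxes = detect_object(detector, frame)
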
def classify_action(classifier, frames, id_to_name):
    # Classify the buffered window of frames and keep every action whose
    # confidence exceeds 0.3, except the catch-all class (assumed to be id 2).
    actions = []
    frames = np.array(frames)
    frames = tf.expand_dims(frames, 0)  # add a batch dimension
    y = classifier(frames)
    confidences = tf.squeeze(y).numpy()
    for (class_id, confidence) in enumerate(confidences):
        other_class_id = 2
        if confidence > 0.3 and class_id != other_class_id:
            actions.append(f'{id_to_name[class_id]}: {confidence:.2f}')
    return actions

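# classify_action assumes a Keras-style video classifier: the input batch has
# shape (1, num_frames, height, width, 3) and the output is one confidence per
# action class. With a hypothetical mapping such as
# id_to_name = {0: 'Takeoff', 1: 'Landing', 2: 'Other'}, one window of frames
# might yield ['Takeoff: 0.87'].
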
def draw_boxes(frame, detections, actions, do_classify):
    (classes, boxes) = detections
    # Find the largest box; only it will carry the action captions.
    max_area = 0
    max_area_id = 0
    for i, box in enumerate(boxes):
        area = (box[3] - box[1]) * (box[2] - box[0])
        if area > max_area:
            max_area = area
            max_area_id = i
    labels = [0 for _ in classes]
    colormap = [(0x39, 0xc5, 0xbb)]
    line_width = 2
    if not do_classify:
        captions = classes
    else:
        captions = [
            f'{class_name}\n' + '\n'.join(actions if i == max_area_id else [])
            for (i, class_name) in enumerate(classes)]
    # imgviz expects boxes as (y1, x1, y2, x2), so reorder from xyxy.
    bboxes = [
        [box[1], box[0], box[3], box[2]]
        for box in boxes]
    frame = instances2rgb(
        frame,
        labels=labels,
        captions=captions,
        bboxes=bboxes,
        colormap=colormap,
        font_size=20,
        line_width=line_width)
    return frame

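# Caption behaviour (illustrative values): only the largest box carries the
# action lines, e.g. captions == ['Airplane\nTakeoff: 0.87', 'Airplane'] when
# two airplanes are detected and the first box has the largest area.
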
def draw_classes(frame, actions):
    height, width, _ = frame.shape
    labels = np.zeros((height, width), dtype=int)
    label_names = ['\n'.join(actions)]
    frame = label2rgb(
        label=labels,
        image=frame,
        label_names=label_names,
        alpha=0)
    return frame

def FrameProcessor(detector, classifier, config: Config):
    current_frame = 0
    frames = []
    actions = []
    detections = ([], [])

    def process_frame(frame):
        nonlocal current_frame, frames, actions, detections
        current_frame += 1
        if current_frame % config.classify_action_frame_steps == 0:
            frames.append(format_frame(frame, config))
        if current_frame % config.detect_object_frame_steps == 0:
            print(f'Detect object: Frame {current_frame}')
            detections = detect_object(detector, frame)
        if len(frames) == config.classify_action_num_frames:
            print(f'Classify action: Until frame {current_frame}')
            # The id-to-name mapping is assumed to live on the config;
            # it is not defined in this file.
            actions = classify_action(classifier, frames, config.id_to_name)
            frames = []
        # Only attach action captions when a classifier is available (assumption).
        frame = draw_boxes(frame, detections, actions, do_classify=classifier is not None)
        return frame

    return process_frame

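# Usage sketch (a minimal example, not part of the original module): it assumes
# cv2 for video I/O, an Ultralytics YOLO detector, and a Keras action
# classifier; the weight/video paths and the default-constructed Config are
# placeholders only.
if __name__ == '__main__':
    import cv2
    from ultralytics import YOLO

    config = Config()  # assumed to be default-constructible
    detector = YOLO('yolov8n.pt')  # illustrative weights file
    classifier = tf.keras.models.load_model('action_classifier.keras')  # placeholder path
    process_frame = FrameProcessor(detector, classifier, config)

    capture = cv2.VideoCapture('input.mp4')  # placeholder video path
    while True:
        ok, frame = capture.read()
        if not ok:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # models are assumed to expect RGB
        frame = process_frame(frame)
    capture.release()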