import os

# Silence TensorFlow's C++ logging; this must be set before tensorflow is imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import time
from typing import Optional

import cv2
import ffmpeg
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image

# deep sort imports
from src.tracker._tools_ import generate_detections as gdet
from src.tracker.application_util import preprocessing
from src.tracker.deep_sort import nn_matching
from src.tracker.deep_sort.detection import Detection
from src.tracker.deep_sort.tracker import Tracker
from src.tracker.mrcnn.mrcnn_color import MRCNN

physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
def check_rotation(path_video_file):
    """Read the rotation tag from the video metadata and return the matching
    cv2 rotate code, or None if no correction is needed."""
    print(path_video_file)
    meta_dict = ffmpeg.probe(path_video_file)
    try:
        rotate = int(meta_dict['streams'][0]['tags']['rotate'])
        if rotate == 90:
            return cv2.ROTATE_90_CLOCKWISE
        elif rotate == 180:
            return cv2.ROTATE_180
        elif rotate == 270:
            return cv2.ROTATE_90_COUNTERCLOCKWISE
    except (KeyError, IndexError, ValueError):
        # No usable rotation tag in the stream metadata.
        return None
    return None


def correct_rotation(frame, rotateCode):
    return cv2.rotate(frame, rotateCode)
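

# A minimal usage sketch for the rotation helpers; "demo.mp4" is a
# hypothetical path, not a value from this repo:
#
#     rotate_code = check_rotation("demo.mp4")
#     ok, frame = cv2.VideoCapture("demo.mp4").read()
#     if ok and rotate_code is not None:
#         frame = correct_rotation(frame, rotate_code)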
class SignboardTracker:
    def __init__(self,
                 detector_checkpoint: str = "./checkpoints/ss/ss.ckpt",
                 input_size: int = 1024,
                 score: float = 0.7,
                 size: int = 1024,
                 video: str = "",
                 output: str = "",
                 output_format: str = "",
                 dont_show: bool = True,
                 info: bool = True,
                 count: bool = True,
                 max_cosine_distance: float = 0.4,
                 nn_budget: Optional[int] = None,
                 nms_max_overlap: float = 1.0,
                 tracker_checkpoint: str = "./checkpoints/tracker/signboard_2793.pb"
                 ) -> None:
        self.detector_checkpoint = detector_checkpoint
        self.input_size = input_size
        self.score = score
        self.size = size
        self.video = video
        self.output = output
        self.output_format = output_format
        self.dont_show = dont_show
        self.info = info
        self.count = count
        self.max_cosine_distance = max_cosine_distance
        self.nn_budget = nn_budget
        self.nms_max_overlap = nms_max_overlap
        self.tracker_checkpoint = tracker_checkpoint
        self.load_tracker()
        self.load_detector()
    def load_tracker(self):
        # Deep SORT appearance encoder plus a cosine-distance association metric.
        self.encoder = gdet.create_box_encoder(self.tracker_checkpoint, batch_size=1)
        metric = nn_matching.NearestNeighborDistanceMetric("cosine", self.max_cosine_distance, self.nn_budget)
        self.tracker = Tracker(metric)

    def load_detector(self):
        # Mask R-CNN signboard detector.
        self.mrcnn = MRCNN(self.detector_checkpoint, self.input_size, self.score)
    def inference_signboard(self, fps_target, video_path, output, output_format, output_frames):
        results = {}   # detections keyed by frame number
        results_ = {}  # detections keyed by track id
        rotateCode = check_rotation(video_path)
        # Accept either a webcam index (e.g. "0") or a file path.
        try:
            vid = cv2.VideoCapture(int(video_path))
        except (ValueError, TypeError):
            vid = cv2.VideoCapture(video_path)
        out = None
        # by default VideoCapture returns float instead of int
        width = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(vid.get(cv2.CAP_PROP_FPS))
        # Process every tg-th frame so the effective rate approximates fps_target.
        # Computed unconditionally: it is needed below even when no output is saved.
        tg = max(1, (fps - (fps % fps_target)) // fps_target)
        # get the video writer ready if an output path is set
        if output:
            codec = cv2.VideoWriter_fourcc(*output_format)
            out = cv2.VideoWriter(output, codec, fps_target, (width, height))
        frame_num = 0
        # while video is running
        while True:
            return_value, frame = vid.read()
            if return_value:
                frame_num += 1
                if rotateCode is not None:
                    frame = correct_rotation(frame, rotateCode)
                image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                image = Image.fromarray(image)
            else:
                break
            if (frame_num % tg) == 0:
                if str(frame_num) not in results:
                    results[str(frame_num)] = []
                start_time = time.time()
                boxes, scores, class_names, class_ids, class_color = self.mrcnn.detect_result_(image, min_score=0.4)
                count = len(class_names)
                # encode the Mask R-CNN detections and feed them to the tracker
                features = self.encoder(frame, boxes)
                detections = [Detection(box, score, class_name, feature)
                              for box, score, class_name, feature in zip(boxes, scores, class_names, features)]
                # initialize color map
                cmap = plt.get_cmap('tab20b')
                colors = [cmap(i)[:3] for i in np.linspace(0, 1, 20)]
                # run non-maxima suppression
                boxs = np.array([d.tlwh for d in detections])
                scores = np.array([d.confidence for d in detections])
                classes = np.array([d.class_name for d in detections])
                indices = preprocessing.non_max_suppression(boxs, classes, self.nms_max_overlap, scores)
                detections = [detections[i] for i in indices]
                # call the tracker
                self.tracker.predict()
                self.tracker.update(detections)
                # update tracks
                for track in self.tracker.tracks:
                    if not track.is_confirmed() or track.time_since_update > 1:
                        continue
                    bbox = track.to_tlbr()
                    class_name = track.get_class()
                    # crop the tracked signboard into the ids folder
                    ids_path = f"{output_frames}"
                    if not os.path.isdir(ids_path):
                        os.makedirs(ids_path)
                    crop_ids = frame[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])]
                    final_ids_path = os.path.join(ids_path, str(track.track_id) + "_" + str(frame_num) + ".png")
                    try:
                        cv2.imwrite(final_ids_path, crop_ids)
                    except Exception as e:
                        print(e)
                    # draw bbox on screen
                    color = colors[int(track.track_id) % len(colors)]
                    color = [i * 255 for i in color]
                    cv2.rectangle(frame, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), color, 2)
                    cv2.rectangle(frame, (int(bbox[0]), int(bbox[1] - 30)), (int(bbox[0]) + (len(class_name) + len(str(track.track_id))) * 17, int(bbox[1])), color, -1)
                    cv2.putText(frame, class_name + "-" + str(track.track_id), (int(bbox[0]), int(bbox[1] - 10)), 0, 0.75, (255, 255, 255), 2)
                    # record this track's detection in both result dicts
                    results[str(frame_num)].append({
                        "id": track.track_id,
                        "class": class_name,
                        "box": [int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])]
                    })
                    if str(track.track_id) not in results_:
                        results_[str(track.track_id)] = []
                    results_[str(track.track_id)].append({
                        "frame": frame_num,
                        "class": class_name,
                        "box": [int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])]
                    })
                result = frame
                # if the output flag is set, save the annotated frame and video
                if output:
                    cv2.imwrite(f"{output}/{frame_num}.jpg", result)
                    out.write(result)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
        # release capture and writer handles before returning
        vid.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()
        return [results, results_]
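

# A minimal usage sketch, assuming the default checkpoints exist on disk; the
# fps target, paths, and fourcc code below are hypothetical placeholders, not
# values taken from this repo.
if __name__ == "__main__":
    tracker = SignboardTracker()
    results_by_frame, results_by_track = tracker.inference_signboard(
        fps_target=5,                    # hypothetical sampling rate
        video_path="./data/street.mp4",  # hypothetical input video
        output="./outputs",              # destination for annotated frames/video
        output_format="XVID",            # fourcc code for cv2.VideoWriter
        output_frames="./outputs/ids",   # per-track crops are written here
    )
    print(f"tracked {len(results_by_track)} signboards over {len(results_by_frame)} sampled frames")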