import argparse
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import time
import tensorflow as tf
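# Memory growth keeps TensorFlow from reserving the whole GPU at start-up,
# so the Mask R-CNN detector and the DeepSORT feature encoder can share it.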
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
from PIL import Image
import cv2
import numpy as np
import matplotlib.pyplot as plt
# deep sort imports
from deep_sort import nn_matching
from application_util import preprocessing
from deep_sort.detection import Detection
from deep_sort.tracker import Tracker
from _tools_ import generate_detections as gdet
# mask r-cnn detector
from mrcnn.mrcnn_color import MRCNN
# ocr (disabled; handle() uses empty placeholders instead)
# from sts.demo.sts import handle_sts
def _parse_args():
    parser = argparse.ArgumentParser(
        description="Signboard detection (Mask R-CNN) and tracking (DeepSORT) demo")
    parser.add_argument("--model",
                        help="path to detection model checkpoint",
                        type=str,
                        default="./checkpoint/maskrcnn_signboard_ss.ckpt")
    parser.add_argument("--input_size",
                        help="detector input size",
                        type=int,
                        default=1024)
    parser.add_argument("--score",
                        help="detection score threshold",
                        type=float,
                        default=0.50)
    parser.add_argument("--size",
                        help="resize images to this size",
                        type=int,
                        default=1024)
    parser.add_argument("--video",
                        help="path to input video, or 0 for webcam",
                        type=str,
                        default="./samples/demo.mp4")
    parser.add_argument("--output",
                        help="path to output video",
                        type=str,
                        default="./outputs/demo.mp4")
    parser.add_argument("--output_format",
                        help="codec used in VideoWriter when saving video to file",
                        type=str,
                        default='mp4v')
    # type=bool is unreliable on the command line (bool("False") is True), so
    # these flags use BooleanOptionalAction (Python 3.9+), which also
    # generates matching --no-* switches.
    parser.add_argument("--dont_show",
                        help="don't show video output",
                        action=argparse.BooleanOptionalAction,
                        default=True)
    parser.add_argument("--info",
                        help="show detailed info of tracked objects",
                        action=argparse.BooleanOptionalAction,
                        default=True)
    parser.add_argument("--count",
                        help="count objects being tracked on screen",
                        action=argparse.BooleanOptionalAction,
                        default=True)
    args = parser.parse_args()
    return args
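# Example invocation (a sketch; the script filename is assumed, the paths are
# the defaults above):
#   python demo.py --video ./samples/demo.mp4 --output ./outputs/demo.mp4 --no-dont_show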
def handle(args):
    # Definition of the parameters
    max_cosine_distance = 0.4
    nn_budget = None
    nms_max_overlap = 1.0
    # initialize deep sort: the encoder embeds each detection crop into an
    # appearance feature vector used for re-identification
    model_filename = 'checkpoint/signboard_2793.pb'
    encoder = gdet.create_box_encoder(model_filename, batch_size=1)
    # cosine distance metric for appearance matching
    metric = nn_matching.NearestNeighborDistanceMetric("cosine", max_cosine_distance, nn_budget)
    # initialize tracker
    tracker = Tracker(metric)
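    # DeepSORT matches detections to existing tracks by combining this
    # appearance metric with Kalman-filter motion gating; nn_budget caps how
    # many past feature vectors are kept per track (None = unlimited).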
    # initialize maskrcnn
    mrcnn = MRCNN(args.model, args.input_size, args.score)
    # begin video capture ("0" opens the webcam)
    video_path = args.video
    try:
        vid = cv2.VideoCapture(int(video_path))
    except ValueError:
        vid = cv2.VideoCapture(video_path)
    # make sure the directories written to below exist
    os.makedirs("./outputs", exist_ok=True)
    os.makedirs("./ids", exist_ok=True)
    out = None
    # get video ready to save locally if flag is set
    if args.output:
        # by default VideoCapture returns float instead of int
        width = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(vid.get(cv2.CAP_PROP_FPS))
        codec = cv2.VideoWriter_fourcc(*args.output_format)
        out = cv2.VideoWriter(args.output, codec, fps, (width, height))
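    # Per-frame pipeline from here on: read frame -> Mask R-CNN detection ->
    # appearance encoding -> NMS -> DeepSORT predict/update -> crop, draw,
    # and log each confirmed track.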
    # color map for per-track box colors (constant, so built once, outside the loop)
    cmap = plt.get_cmap('tab20b')
    colors = [cmap(i)[:3] for i in np.linspace(0, 1, 20)]
    frame_num = 0
    # while video is running
    while True:
        return_value, frame = vid.read()
        if return_value:
            image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(image)
        else:
            print('Video has ended or failed, try a different video format!')
            break
        frame_num += 1
        print('Frame #: ', frame_num)
        start_time = time.time()
        boxes, scores, class_names, class_ids, class_color = mrcnn.detect_result_(image, min_score=args.score)
        count = len(class_names)
        if args.count:
            cv2.putText(frame, "Objects being tracked: {0}".format(count), (5, 35),
                        cv2.FONT_HERSHEY_COMPLEX_SMALL, 2, (0, 255, 0), 2)
            print("Objects being tracked: {0}".format(count))
        # encode mask r-cnn detections and feed to tracker
        features = encoder(frame, boxes)
        detections = [Detection(box, score, class_name, feature)
                      for box, score, class_name, feature in zip(boxes, scores, class_names, features)]
        # run non-maxima suppression
        boxs = np.array([d.tlwh for d in detections])
        scores = np.array([d.confidence for d in detections])
        classes = np.array([d.class_name for d in detections])
        indices = preprocessing.non_max_suppression(boxs, classes, nms_max_overlap, scores)
        detections = [detections[i] for i in indices]
        # Call the tracker: predict() advances each track's Kalman state,
        # update() runs the matching cascade and creates/deletes tracks
        tracker.predict()
        tracker.update(detections)
# update tracks | |
with open("./outputs/{}.txt".format(frame_num), "a+", encoding="utf-8") as ff: | |
for track in tracker.tracks: | |
if not track.is_confirmed() or track.time_since_update > 1: | |
continue | |
bbox = track.to_tlbr() | |
# crop to ids folder | |
ids_path = "./ids/"+str(track.track_id) | |
if not os.path.isdir(ids_path): | |
os.mkdir(ids_path) | |
crop_ids = frame[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])] | |
num_ids = 0 | |
while os.path.isfile(os.path.join(ids_path, str(track.track_id) + "_" + str(frame_num) + "_" + str(num_ids)+".png")): | |
num_ids += 1 | |
final_ids_path = os.path.join(ids_path, str(track.track_id) + "_" + str(frame_num) + "_" + str(num_ids)+".png") | |
cv2.imwrite(final_ids_path, crop_ids) | |
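            # second pass over the confirmed tracks: run OCR on each crop
            # (stubbed out below), draw boxes and labels, and append one line
            # per track to this frame's log file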
            for track in tracker.tracks:
                if not track.is_confirmed() or track.time_since_update > 1:
                    continue
                bbox = track.to_tlbr()
                class_name = track.get_class()
                # predict ocr (handle_sts is disabled; empty placeholders keep
                # the drawing and logging code below working)
                crop_ids = frame[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])]
                dict_box_sign_out, dict_rec_sign_out = [], []  # handle_sts(crop_ids)
                # draw bbox on screen
                color = colors[int(track.track_id) % len(colors)]
                color = [i * 255 for i in color]
                cv2.rectangle(frame, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), color, 2)
                cv2.rectangle(frame, (int(bbox[0]), int(bbox[1] - 30)),
                              (int(bbox[0]) + (len(class_name) + len(str(track.track_id))) * 17, int(bbox[1])), color, -1)
                cv2.putText(frame, class_name + "-" + str(track.track_id), (int(bbox[0]), int(bbox[1] - 10)), 0, 0.75, (255, 255, 255), 2)
                dict_rec_sign_out_join = "_".join(dict_rec_sign_out)
                cv2.putText(frame, dict_rec_sign_out_join, (int(bbox[0]), int(bbox[1] + 20)), 0, 0.75, (255, 255, 255), 2)
                # if enable info flag then print details about each track
                if args.info:
                    print("Tracker ID: {}, Class: {}, BBox Coords (xmin, ymin, xmax, ymax): {}".format(
                        str(track.track_id), class_name, (int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]))))
                ff.write("{}, {}, {}, {}, {}, {}\n".format(str(track.track_id), int(bbox[0]), int(bbox[1]),
                                                           int(bbox[2]), int(bbox[3]), dict_rec_sign_out_join))
        # calculate frames per second of running detections
        fps = 1.0 / (time.time() - start_time)
        print("FPS: %.2f" % fps)
        result = frame
        if not args.dont_show:
            cv2.imshow("Output Video", result)
        # if output flag is set, save the annotated frame and video file
        if args.output:
            cv2.imwrite("./outputs/{0}.jpg".format(frame_num), result)
            out.write(result)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    # release handles so the output video is flushed to disk
    vid.release()
    if out is not None:
        out.release()
    cv2.destroyAllWindows()


def main():
    args = _parse_args()
    handle(args)


if __name__ == '__main__':
    main()
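# Outputs produced by handle(): the annotated video at --output, one JPEG and
# one <frame>.txt log per frame under ./outputs/, and per-track crops under
# ./ids/<track_id>/.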