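"""Gradio app: detect and count products in a video with YOLOv8 + ByteTrack.

Each frame is run through a YOLO model, detections are associated across
frames with ByteTrack, and objects inside a polygonal counting zone are
annotated together with per-class counts before being written to an
output video.
"""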
import os

# ByteTrack and its dependencies are not plain pip packages, so they are
# installed at startup before the imports below that need them.
os.system("pip install git+https://github.com/ifzhang/ByteTrack")
os.system("pip3 install cython_bbox gdown 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'")
os.system("pip3 install -v -e .")

from dataclasses import dataclass
from typing import List, Optional

import gradio as gr
import numpy as np
import supervision
from onemetric.cv.utils.iou import box_iou_batch
from supervision import (
    BoxAnnotator,
    Color,
    Detections,
    Point,
    VideoInfo,
    VideoSink,
    draw_text,
    get_video_frames_generator,
)
from tqdm import tqdm
from ultralytics import YOLO
from yolox.tracker.byte_tracker import BYTETracker, STrack
MODEL = "./best.pt"               # fine-tuned YOLO weights
TARGET_VIDEO_PATH = "test.mp4"    # annotated output video
CLASS_ID = [0, 1, 2, 3, 4, 5, 6]  # class ids to keep from the model output
video_examples = [["example.mp4"]]

model = YOLO(MODEL)
model.fuse()  # fuse conv+bn layers for faster inference

# Labels are rendered from the raw class ids; swap in a name lookup here
# if human-readable class names are available.
classes = CLASS_ID
@dataclass(frozen=True)
class BYTETrackerArgs:
    track_thresh: float = 0.25        # detection confidence threshold
    track_buffer: int = 30            # frames to keep lost tracks alive
    match_thresh: float = 0.8         # IoU threshold for track matching
    aspect_ratio_thresh: float = 3.0  # drop boxes that are too tall and narrow
    min_box_area: float = 1.0         # drop boxes smaller than this area
    mot20: bool = False               # MOT20-specific matching tweaks
# Converts Detections into the (x1, y1, x2, y2, confidence) array consumed
# by BYTETracker.update.
def detections2boxes(detections: Detections) -> np.ndarray:
    return np.hstack((
        detections.xyxy,
        detections.confidence[:, np.newaxis]
    ))
# Converts a List[STrack] into an (N, 4) array of (x1, y1, x2, y2) boxes
# for match_detections_with_tracks.
def tracks2boxes(tracks: List[STrack]) -> np.ndarray:
    return np.array([track.tlbr for track in tracks], dtype=float)
# Matches detections with ByteTrack tracks via IoU and returns one tracker
# id (or None) per detection.
def match_detections_with_tracks(
    detections: Detections,
    tracks: List[STrack],
) -> List[Optional[int]]:
    if not np.any(detections.xyxy) or len(tracks) == 0:
        return [None] * len(detections)

    tracks_boxes = tracks2boxes(tracks=tracks)
    iou = box_iou_batch(tracks_boxes, detections.xyxy)
    track2detection = np.argmax(iou, axis=1)

    tracker_ids = [None] * len(detections)
    for tracker_index, detection_index in enumerate(track2detection):
        if iou[tracker_index, detection_index] != 0:
            tracker_ids[detection_index] = tracks[tracker_index].track_id
    return tracker_ids
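
# Full pipeline for one uploaded video: detect with YOLO, associate
# detections across frames with ByteTrack, keep only objects inside the
# counting zone, draw per-class counts, and write the annotated frames
# to TARGET_VIDEO_PATH.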
def ObjectDetection(video_path):
    byte_tracker = BYTETracker(BYTETrackerArgs())
    video_info = VideoInfo.from_video_path(video_path)
    generator = get_video_frames_generator(video_path)
    box_annotator = BoxAnnotator(thickness=5, text_thickness=5, text_scale=1)

    # Counting zone: a hard-coded polygon in pixel coordinates; adjust the
    # vertices for inputs with a different resolution.
    polygon = np.array([[200, 300], [200, 1420], [880, 1420], [880, 300]])
    zone = supervision.PolygonZone(polygon=polygon, frame_resolution_wh=video_info.resolution_wh)
    zone_annotator = supervision.PolygonZoneAnnotator(zone=zone, color=Color.white(), thickness=4)
    # Process the video frame by frame and write the annotated result.
    with VideoSink(TARGET_VIDEO_PATH, video_info) as sink:
        for frame in tqdm(generator, total=video_info.total_frames):
            results = model(frame)
            detections = Detections(
                xyxy=results[0].boxes.xyxy.cpu().numpy(),
                confidence=results[0].boxes.conf.cpu().numpy(),
                class_id=results[0].boxes.cls.cpu().numpy().astype(int)
            )
            # Keep only the classes we want to count.
            detections = detections[np.isin(detections.class_id, CLASS_ID)]

            # Update the tracker and attach tracker ids to the detections.
            tracks = byte_tracker.update(
                output_results=detections2boxes(detections=detections),
                img_info=frame.shape,
                img_size=frame.shape
            )
            tracker_id = match_detections_with_tracks(detections=detections, tracks=tracks)
            detections.tracker_id = np.array(tracker_id)

            # Drop detections that could not be matched to a track.
            detections = detections[np.not_equal(detections.tracker_id, None)]

            # Keep only detections inside the counting zone.
            mask = zone.trigger(detections=detections)
            detections_filtered = detections[mask]

            # Label each box with its tracker id, class, and confidence
            # (built from the filtered detections so labels and boxes align).
            labels = [
                f"#{tracker_id} {classes[class_id]} {confidence:0.2f}"
                for _, _, confidence, class_id, tracker_id
                in detections_filtered
            ]

            # Draw a per-class count of the objects currently in the zone.
            class_ids, counts = np.unique(detections_filtered.class_id, return_counts=True)
            for class_id, count in zip(class_ids, counts):
                frame = draw_text(
                    scene=frame,
                    text=f"{classes[class_id]} : {count}",
                    text_anchor=Point(x=500, y=1550 + (50 * class_id)),
                    text_scale=2,
                    text_thickness=4,
                    background_color=Color.white()
                )

            frame = box_annotator.annotate(scene=frame, detections=detections_filtered, labels=labels)
            frame = zone_annotator.annotate(scene=frame)
            sink.write_frame(frame)
    return TARGET_VIDEO_PATH
# Gradio UI: upload a video, get back the annotated, counted version.
demo = gr.Interface(
    fn=ObjectDetection,
    inputs=gr.Video(),
    outputs=gr.Video(),
    examples=video_examples,
    cache_examples=False
)
demo.queue().launch()