human-detector / mivolo /predictor.py
George
upl all codes
b5f33fd
raw
history blame
No virus
2.65 kB
from collections import defaultdict
from typing import Dict, Generator, List, Optional, Tuple
import cv2
import numpy as np
import tqdm
from mivolo.model.mi_volo import MiVOLO
from mivolo.model.yolo_detector import Detector
from mivolo.structures import AGE_GENDER_TYPE, PersonAndFaceResult
class Predictor:
    """Run person/face detection followed by MiVOLO age and gender estimation.

    Wraps a ``Detector`` and a ``MiVOLO`` model built from one ``config``
    object. It exposes a single-image entry point (``recognize``) and a
    tracking-based video entry point (``recognize_video``).
    """

    def __init__(self, config, verbose: bool = False):
        """Build the detector and the age/gender model from ``config``.

        Args:
            config: Object read for the attributes ``detector_weights``,
                ``device``, ``checkpoint``, ``with_persons``,
                ``disable_faces`` and ``draw``.
            verbose: Forwarded to both the detector and the MiVOLO model.
        """
        self.detector = Detector(config.detector_weights, config.device, verbose=verbose)
        self.age_gender_model = MiVOLO(
            config.checkpoint,
            config.device,
            half=True,  # hard-coded, not taken from config (presumably half-precision inference)
            use_persons=config.with_persons,
            disable_faces=config.disable_faces,
            verbose=verbose,
        )
        # When true, outputs are rendered through PersonAndFaceResult.plot().
        self.draw = config.draw

    def recognize(self, image: np.ndarray) -> Tuple[PersonAndFaceResult, Optional[np.ndarray]]:
        """Detect people/faces in ``image`` and estimate their age and gender.

        Args:
            image: Input image, passed unchanged to the detector and to MiVOLO
                (presumably an OpenCV-style BGR array -- confirm with callers).

        Returns:
            ``(detected_objects, out_im)``. ``out_im`` is the image returned by
            ``detected_objects.plot()`` when ``self.draw`` is set; otherwise it
            is ``None``.
        """
        detected_objects: PersonAndFaceResult = self.detector.predict(image)
        # The return value is ignored: MiVOLO is expected to write its results
        # into ``detected_objects`` in place (NOTE(review): confirm in MiVOLO.predict).
        self.age_gender_model.predict(image, detected_objects)

        out_im: Optional[np.ndarray] = detected_objects.plot() if self.draw else None
        return detected_objects, out_im

    @staticmethod
    def _update_history(
        history: Dict[int, List[AGE_GENDER_TYPE]], detected_objects: PersonAndFaceResult
    ) -> None:
        """Append this frame's complete person and face results to ``history``.

        Persons are processed before faces, and both share one id space in
        ``history``. Entries containing ``None`` (any missing field) are
        skipped, so incomplete predictions never enter the per-track history.
        """
        current_frame_objs = detected_objects.get_results_for_tracking()
        cur_persons: Dict[int, AGE_GENDER_TYPE] = current_frame_objs[0]
        cur_faces: Dict[int, AGE_GENDER_TYPE] = current_frame_objs[1]
        for tracked in (cur_persons, cur_faces):
            for guid, data in tracked.items():
                if None not in data:
                    history[guid].append(data)

    def recognize_video(self, source: str) -> Generator:
        """Track people/faces through a video and estimate age/gender per track.

        Frames are read until ``VideoCapture.read`` reports that none are
        left. ``CAP_PROP_FRAME_COUNT`` is used only to size the progress bar.
        That value can be inaccurate, and may be 0 or negative for live
        streams, so using it as the loop bound would skip or truncate input.

        Args:
            source: String accepted by ``cv2.VideoCapture`` (file path, URL, ...).

        Yields:
            ``(history, frame)`` for every frame read.

            - ``history`` maps a track id to the list of complete age/gender
              entries collected for it so far. The same dict object is
              yielded every time and keeps growing.
            - ``frame`` is the image from ``detected_objects.plot()`` when
              ``self.draw`` is set; otherwise it is the raw frame.

        Raises:
            ValueError: If ``source`` cannot be opened. Because this is a
                generator, the error surfaces on the first ``next()``.
        """
        video_capture = cv2.VideoCapture(source)
        try:
            if not video_capture.isOpened():
                raise ValueError(f"Failed to open video source {source}")

            detected_objects_history: Dict[int, List[AGE_GENDER_TYPE]] = defaultdict(list)
            total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))

            # total=None gives an open-ended counter when the length is unknown.
            with tqdm.tqdm(total=total_frames if total_frames > 0 else None) as progress:
                while True:
                    ret, frame = video_capture.read()
                    if not ret:
                        break
                    progress.update(1)

                    detected_objects: PersonAndFaceResult = self.detector.track(frame)
                    self.age_gender_model.predict(frame, detected_objects)

                    self._update_history(detected_objects_history, detected_objects)
                    detected_objects.set_tracked_age_gender(detected_objects_history)

                    if self.draw:
                        frame = detected_objects.plot()
                    yield detected_objects_history, frame
        finally:
            # Release the capture on normal exhaustion, on errors, and when the
            # consumer stops iterating early (generator close / garbage collection).
            video_capture.release()