Hamidreza-Hashemp's picture
Upload 65 files
cd3346a
raw
history blame
No virus
1.94 kB
import os
# Limit the MKL / numexpr / OpenMP thread pools to a single thread each.
# NOTE(review): these are assigned before cv2 / facenet_pytorch / torch are
# imported below, presumably so the native libraries pick the limits up at
# load time — confirm before moving any import above these lines.
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import List
import cv2
# Turn off OpenCV's OpenCL path and its internal threading (0 = run
# sequentially). Presumably done to avoid thread oversubscription when videos
# are decoded in parallel workers — confirm.
cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)
from PIL import Image
from facenet_pytorch.models.mtcnn import MTCNN
from torch.utils.data import Dataset
class VideoFaceDetector(ABC):
    """Abstract interface for detectors that find faces in batches of frames.

    Concrete subclasses must supply the read-only ``_batch_size`` property and
    the ``_detect_faces`` method; the class cannot be instantiated otherwise.
    """

    def __init__(self, **_ignored) -> None:
        # Keyword arguments are accepted so callers may pass options freely,
        # but the base class does nothing with them.
        super().__init__()

    @property
    @abstractmethod
    def _batch_size(self) -> int:
        """Batch size advertised by the detector (presumably frames per call)."""

    @abstractmethod
    def _detect_faces(self, frames) -> List:
        """Detect faces in *frames* and return the results as a list."""
class FacenetDetector(VideoFaceDetector):
    """Video face detector backed by facenet-pytorch's ``MTCNN``.

    Args:
        device: Device passed to ``MTCNN`` (default ``"cuda:0"``).
        margin: ``margin`` passed to ``MTCNN`` (default ``0``).
        thresholds: The three ``thresholds`` values passed to ``MTCNN``
            (default ``(0.85, 0.95, 0.95)``); converted to a list.
        batch_size: Value reported by ``_batch_size`` (default ``32``).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    # Class-level fallback so ``_batch_size`` still reports 32 for subclasses
    # that override ``__init__`` without calling this one.
    _frames_per_batch = 32

    def __init__(self, device="cuda:0", margin=0,
                 thresholds=(0.85, 0.95, 0.95), batch_size=32) -> None:
        super().__init__()
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
        self._frames_per_batch = int(batch_size)
        # The default is an immutable tuple; MTCNN gets a fresh list per
        # instance, so no mutable default is shared between detectors.
        self.detector = MTCNN(margin=margin, thresholds=list(thresholds), device=device)

    def _detect_faces(self, frames) -> List:
        """Run MTCNN on *frames* and return the detected bounding boxes.

        Returns:
            A list with one entry per element of MTCNN's box output (one per
            frame when *frames* is a list): the boxes converted with
            ``tolist()``, or ``None`` where MTCNN found no face. An empty
            list/tuple of frames yields ``[]``.
        """
        # Short-circuit empty batches rather than handing an empty list to
        # MTCNN, whose batch path stacks the inputs (presumably failing on []).
        if isinstance(frames, (list, tuple)) and len(frames) == 0:
            return []
        # Only the boxes are kept; any further outputs of detect() are dropped.
        batch_boxes, *_ = self.detector.detect(frames, landmarks=False)
        return [b.tolist() if b is not None else None for b in batch_boxes]

    @property
    def _batch_size(self) -> int:
        """Number of frames this detector reports as its batch size."""
        return self._frames_per_batch
class VideoDataset(Dataset):
    """Dataset that decodes every frame of a video into downscaled PIL images.

    Each item is ``(video, frame_indices, frames)``:

    * ``video`` is the entry from ``videos`` that was opened;
    * ``frame_indices`` lists the positions of the frames that were decoded
      successfully (frames that fail to retrieve are skipped, leaving gaps);
    * ``frames`` are the matching RGB ``PIL.Image`` objects, with width and
      height floor-divided by ``downscale`` (clamped to at least 1 pixel).

    Args:
        videos: Indexable collection of video sources passed to
            ``cv2.VideoCapture`` (presumably file paths).
        downscale: Integer factor applied to each frame dimension (default 2).

    Raises:
        ValueError: If ``downscale`` is smaller than 1.
    """

    # Class-level fallback so subclasses that skip ``__init__`` keep the
    # original halving behaviour.
    downscale = 2

    def __init__(self, videos, downscale: int = 2) -> None:
        super().__init__()
        if downscale < 1:
            raise ValueError(f"downscale must be >= 1, got {downscale!r}")
        self.videos = videos
        self.downscale = downscale

    def __getitem__(self, index: int):
        video = self.videos[index]
        capture = cv2.VideoCapture(video)
        indices = []
        frames = []
        try:
            # The loop is bounded by the reported frame count; frames that
            # cannot be retrieved are skipped rather than aborting the video.
            frames_num = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
            for i in range(frames_num):
                capture.grab()
                success, frame = capture.retrieve()
                if not success:
                    continue
                # OpenCV decodes to BGR; PIL consumers expect RGB.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                image = Image.fromarray(frame)
                # Clamp to 1 px so very small frames don't produce a zero size,
                # which Pillow's resize rejects.
                new_size = tuple(max(1, s // self.downscale) for s in image.size)
                indices.append(i)
                frames.append(image.resize(size=new_size))
        finally:
            # Release the decoder deterministically instead of relying on
            # garbage collection of the VideoCapture object.
            capture.release()
        return video, indices, frames

    def __len__(self) -> int:
        return len(self.videos)