# Web-viewer residue captured with this file (not part of the program):
# mpc001's picture
# Upload 125 files
# 09481f3
# raw
# history blame
# 1.45 kB
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2021 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import warnings
import torchvision
from ibug.face_detection import RetinaFacePredictor
from ibug.face_alignment import FANPredictor
# Import-time side effect: installs a blanket "ignore" filter, so warnings of
# every category are suppressed for the whole process once this module loads.
warnings.filterwarnings("ignore")
class LandmarksDetector:
    """Per-frame facial landmark extraction for the largest face in a video.

    A ``RetinaFacePredictor`` proposes face boxes for each frame and a
    ``FANPredictor`` predicts landmarks for those boxes; only the landmarks
    of the largest box (by width + height) are kept for each frame.
    """

    def __init__(self, device="cuda:0", model_name='resnet50'):
        """Create the face detector and the landmark predictor.

        Args:
            device: Device string forwarded to both predictors.
            model_name: Name forwarded to ``RetinaFacePredictor.get_model``
                to select the face-detection model.
        """
        self.face_detector = RetinaFacePredictor(
            device=device,
            threshold=0.8,
            model=RetinaFacePredictor.get_model(model_name),
        )
        self.landmark_detector = FANPredictor(device=device, model=None)

    @staticmethod
    def _largest_face_index(detected_faces):
        """Return the index of the box with the largest width + height.

        Each box is indexed as ``(x0, y0, x1, y1, ...)``. Ties keep the
        earliest box, and index 0 is returned when no box has a positive
        size.
        """
        max_id, max_size = 0, 0
        for idx, bbox in enumerate(detected_faces):
            bbox_size = (bbox[2] - bbox[0]) + (bbox[3] - bbox[1])
            if bbox_size > max_size:
                max_id, max_size = idx, bbox_size
        return max_id

    def _frame_landmarks(self, frame):
        """Return the landmarks of the largest face in *frame*, or ``None``."""
        # NOTE(review): the two predictors receive different ``rgb`` flags for
        # the same frame; kept exactly as in the original -- confirm intended.
        detected_faces = self.face_detector(frame, rgb=False)
        # Explicit length test rather than truthiness: the detector's return
        # type is not visible here and may not support ``bool()``.
        if len(detected_faces) == 0:
            # No face: skip the landmark predictor entirely instead of running
            # it on an empty detection set and discarding the result.
            return None
        face_points, _ = self.landmark_detector(frame, detected_faces, rgb=True)
        return face_points[self._largest_face_index(detected_faces)]

    def __call__(self, filename):
        """Read a video file and detect landmarks frame by frame.

        Args:
            filename: Path of the video, passed to
                ``torchvision.io.read_video``.

        Returns:
            A list with one entry per decoded frame: the landmark
            predictor's output for the largest detected face, or ``None``
            when no face was detected in that frame.
        """
        # read_video returns a tuple; element 0 holds the video frames.
        video_frames = torchvision.io.read_video(filename, pts_unit='sec')[0].numpy()
        return [self._frame_landmarks(frame) for frame in video_frames]