natexcvi
Fix shape
3ea82c1 unverified
raw
history blame
2.67 kB
import threading
import cv2
import mediapipe as mp
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
from tensorflow import keras
class Model:
    """Keras model downloaded from the Hugging Face Hub plus landmark preprocessing.

    Images are turned into face-landmark arrays by a ``LandmarkExtractor`` and
    the loaded Keras model is run on those arrays via :meth:`predict`.
    """

    # Number of landmarks per face the model input is shaped around. Presumably
    # matches FaceMesh output with ``refine_landmarks=True`` in LandmarkExtractor
    # — confirm if the extractor settings change.
    NUM_LANDMARKS = 478

    def __init__(self, model_repo_id: str, model_filename: str, hf_token: str):
        """Create the landmark extractor, then download and load the Keras model.

        Args:
            model_repo_id: Hugging Face Hub repository id holding the model file.
            model_filename: Name of the model file inside that repository.
            hf_token: Access token forwarded to ``hf_hub_download``.
        """
        self.landmark_extractor = LandmarkExtractor()
        model_path = hf_hub_download(
            repo_id=model_repo_id,
            filename=model_filename,
            token=hf_token,
            cache_dir=".cache",  # downloads are cached under ./.cache
        )
        self.trained_model = keras.models.load_model(model_path)

    def predict(self, x: np.ndarray):
        """Run the loaded Keras model on ``x`` (typically the output of :meth:`preprocess`)."""
        return self.trained_model.predict(x)

    def preprocess(self, image: bytes) -> np.ndarray:
        """Decode encoded image bytes and convert the detected face into model input.

        Args:
            image: Encoded image file contents (any format ``cv2.imdecode`` reads).

        Returns:
            Float64 array of shape ``(1, NUM_LANDMARKS, k)``, where ``k`` is the
            number of coordinates per landmark (x, y, z from the extractor).

        Raises:
            ValueError: if ``image`` is empty or cannot be decoded, or if no face
                is detected in the decoded image.
        """
        encoded = np.asarray(bytearray(image), dtype=np.uint8)
        # imdecode returns None for data it cannot decode; reject empty input
        # up front as well so both cases produce the same clear error.
        decoded_img = (
            cv2.imdecode(encoded, flags=cv2.IMREAD_COLOR) if encoded.size else None
        )
        if decoded_img is None:
            raise ValueError("Could not decode image bytes")
        landmarks = self.landmark_extractor.extract_landmarks_flat(decoded_img)
        return self._landmarks_to_array(landmarks)

    @classmethod
    def _landmarks_to_array(cls, landmarks) -> np.ndarray:
        """Convert a flat ``{"<i>_x": ..., "<i>_y": ..., "<i>_z": ...}`` dict to model input.

        Values are taken in the dict's insertion order and reshaped to
        ``(1, NUM_LANDMARKS, -1)``.

        Raises:
            ValueError: if ``landmarks`` is None (no face detected).
        """
        if landmarks is None:
            # Without this check the reshape below fails with an opaque
            # "cannot reshape array" error instead of saying what went wrong.
            raise ValueError("No face detected in image")
        flat = np.fromiter(landmarks.values(), dtype=np.float64, count=len(landmarks))
        return flat.reshape(1, cls.NUM_LANDMARKS, -1)

    @staticmethod
    def distance(x1: np.ndarray, x2: np.ndarray):
        """Return ``np.linalg.norm(x1 - x2, ord=2)``.

        For 1-D inputs this is the Euclidean distance. For 2-D inputs NumPy's
        ``ord=2`` is the spectral norm (largest singular value), which equals the
        Euclidean norm when the difference is a single row or column.
        """
        return np.linalg.norm(x1 - x2, ord=2)
class LandmarkExtractor:
    """Detect one face with MediaPipe FaceMesh and return eye-normalised landmarks."""

    def __init__(self):
        # Serialises access to self.face_mesh: one FaceMesh instance is shared by
        # every caller of this extractor, and its process() is assumed not to be
        # safe to call concurrently.
        self.lock = threading.Lock()
        self.face_mesh = mp.solutions.face_mesh.FaceMesh(
            static_image_mode=True,
            max_num_faces=1,
            refine_landmarks=True,
            min_detection_confidence=0.5,
        )

    @staticmethod
    def __normalise_landmarks(landmarks):
        """Re-express each landmark's (x, y) in the basis of the two eye landmarks.

        For every landmark ``p`` this solves ``a * left_eye + b * right_eye = p``
        (using landmarks 33 and 263 as the basis vectors) and writes ``(a, b)``
        back into ``p.x`` / ``p.y``. ``z`` is left unchanged. The landmark
        objects are mutated in place and the same sequence is returned.

        Raises:
            numpy.linalg.LinAlgError: if the two eye vectors are exactly
                linearly dependent (singular basis).
        """
        left_eye_idx = 33
        right_eye_idx = 263
        left_eye, right_eye = landmarks[left_eye_idx], landmarks[right_eye_idx]
        # Columns of the basis matrix are the left- and right-eye (x, y) vectors.
        basis = np.array([[left_eye.x, left_eye.y], [right_eye.x, right_eye.y]]).T
        points = np.array([[lmk.x, lmk.y] for lmk in landmarks]).T  # shape (2, n)
        xy_norm = np.linalg.solve(basis, points).T  # shape (n, 2)
        for lmk, (new_x, new_y) in zip(landmarks, xy_norm):
            lmk.x = new_x
            lmk.y = new_y
        return landmarks

    def extract_landmarks(self, img: np.ndarray):
        """Return eye-normalised landmarks of the first detected face, or None.

        Args:
            img: Image in OpenCV's BGR channel order (it is converted to RGB
                before being passed to FaceMesh).
        """
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # The lock was previously created but never acquired, leaving the shared
        # FaceMesh unprotected against concurrent callers.
        with self.lock:
            results = self.face_mesh.process(rgb)
        if results.multi_face_landmarks is None:
            return None
        return self.__normalise_landmarks(results.multi_face_landmarks[0].landmark)

    @staticmethod
    def _flatten_landmarks(landmarks):
        """Flatten landmarks to ``{"<i>_x": x, "<i>_y": y, "<i>_z": z, ...}``.

        Keys are inserted in landmark-index order, with x, y, z for each index.
        Downstream code relies on that ordering when reshaping the values.
        """
        return {
            f"{i}_{axis}": getattr(lmk, axis)
            for i, lmk in enumerate(landmarks)
            for axis in ("x", "y", "z")
        }

    def extract_landmarks_flat(self, img: np.ndarray):
        """Like :meth:`extract_landmarks`, but return a flat ``"<i>_<axis>"`` dict (or None)."""
        landmarks = self.extract_landmarks(img)
        if landmarks is None:
            return None
        return self._flatten_landmarks(landmarks)