import threading

import cv2
import mediapipe as mp
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
from tensorflow import keras


class Model:
    """Triplet-embedding model over MediaPipe face-mesh landmarks."""

    def __init__(self, model_repo_id: str, model_filename: str, hf_token: str):
        self.landmark_extractor = LandmarkExtractor()
        # Build the production architecture locally, then copy in the trained
        # weights downloaded from the Hugging Face Hub.
        self.trained_model = self.create_prod_model()
        self.trained_model.compile(
            keras.optimizers.Adam(0.0005),
            loss=self.triplet_loss_init(0.2),
            metrics=[self.triplet_accuracy],
        )
        custom_objects = {
            "triplet_loss": self.triplet_loss_init(0.2),
            "triplet_accuracy": self.triplet_accuracy,
            "K": keras.backend,
            "keras": keras,
        }
        with keras.utils.custom_object_scope(custom_objects):
            weights = keras.models.load_model(
                hf_hub_download(model_repo_id, model_filename, token=hf_token)
            ).get_weights()
        self.trained_model.set_weights(weights)

    @staticmethod
    def fec_net(inputs):
        # Embedding head: three dense layers with dropout, projected to a
        # 16-dimensional, L2-normalised embedding.
        x = keras.layers.Dense(1024, activation="relu")(inputs)
        x = keras.layers.Dense(512, activation="relu")(x)
        x = keras.layers.Dense(512, activation="relu")(x)
        x = keras.layers.Dropout(0.5)(x)
        x = keras.layers.Dense(16)(x)
        outputs = keras.layers.Lambda(keras.backend.l2_normalize)(x)
        return outputs

    @classmethod
    def create_prod_model(cls):
        # 478 face-mesh landmarks, each contributing (x, y, z) coordinates.
        inputs = keras.layers.Input(shape=(478 * 3,))
        outputs = cls.fec_net(inputs)
        return keras.Model(inputs=inputs, outputs=outputs)

    @staticmethod
    def triplet_loss_init(alpha):
        def triplet_loss(y_true, y_pred):
            # y_true is unused: the triplet structure is encoded in y_pred,
            # which concatenates three 16-d embeddings: the two "similar"
            # samples (s1, s2) and the "different" one (d).
            dimensions = 16
            s1 = y_pred[:, :dimensions]
            s2 = y_pred[:, dimensions : 2 * dimensions]
            d = y_pred[:, 2 * dimensions :]
            # Squared Euclidean distances between the embeddings.
            s1_s2 = keras.backend.sum(keras.backend.square(s1 - s2), axis=1)
            s1_d = keras.backend.sum(keras.backend.square(s1 - d), axis=1)
            s2_d = keras.backend.sum(keras.backend.square(s2 - d), axis=1)
            # Hinge on both margins: the similar pair should be closer than
            # either pair involving the different sample, by at least alpha.
            loss = keras.backend.maximum(
                0.0, s1_s2 - s1_d + alpha
            ) + keras.backend.maximum(0.0, s1_s2 - s2_d + alpha)
            return keras.backend.mean(loss)

        return triplet_loss

    @staticmethod
    def triplet_accuracy(y_true, y_pred):
        # Fraction of triplets where the similar pair (s1, s2) is strictly
        # closer than either pair involving the different sample d.
        dimensions = 16
        s1 = y_pred[:, :dimensions]
        s2 = y_pred[:, dimensions : 2 * dimensions]
        d = y_pred[:, 2 * dimensions :]
        s1_s2 = keras.backend.sqrt(
            keras.backend.sum(keras.backend.square(s1 - s2), axis=1)
        )
        s1_d = keras.backend.sqrt(
            keras.backend.sum(keras.backend.square(s1 - d), axis=1)
        )
        s2_d = keras.backend.sqrt(
            keras.backend.sum(keras.backend.square(s2 - d), axis=1)
        )
        s1_match = keras.backend.less(s1_s2, s1_d)
        s2_match = keras.backend.less(s1_s2, s2_d)
        match = keras.backend.cast(
            keras.backend.all(keras.backend.stack([s1_match, s2_match]), axis=0),
            "float32",
        )
        return keras.backend.mean(match, axis=0)

    @classmethod
    def load_model(cls, model_path: str):
        custom_objects = {
            "triplet_loss": cls.triplet_loss_init(0.2),
            "triplet_accuracy": cls.triplet_accuracy,
        }
        with keras.utils.custom_object_scope(custom_objects):
            return keras.models.load_model(model_path)

    def predict(self, x: np.ndarray):
        return self.trained_model.predict(x)

    def preprocess(self, image: bytes) -> np.ndarray:
        # Decode the raw bytes, extract landmarks, and build a (1, 478 * 3) feature row.
        array_repr = np.asarray(bytearray(image), dtype=np.uint8)
        decoded_img = cv2.imdecode(array_repr, flags=cv2.IMREAD_COLOR)
        landmarks = self.landmark_extractor.extract_landmarks_flat(decoded_img)
        if landmarks is None:
            raise ValueError("No face detected in the input image")
        return pd.DataFrame([landmarks]).to_numpy()

    @staticmethod
    def distance(x1: np.ndarray, x2: np.ndarray):
        # Euclidean (L2) distance between two embeddings.
        return np.linalg.norm(x1 - x2, ord=2)


class LandmarkExtractor:
    def __init__(self):
        # FaceMesh instances are not thread-safe; the lock serialises access.
        self.lock = threading.Lock()
        self.face_mesh = mp.solutions.face_mesh.FaceMesh(
            static_image_mode=True,
            max_num_faces=1,
            refine_landmarks=True,
            min_detection_confidence=0.5,
        )

    @staticmethod
    def __normalise_landmarks(landmarks):
        # Change of basis: express every landmark's (x, y) in the coordinate
        # system spanned by the left- and right-eye position vectors.
        left_eye_idx = 33
        right_eye_idx = 263
        left_eye_v = [landmarks[left_eye_idx].x, landmarks[left_eye_idx].y]
        right_eye_v = [landmarks[right_eye_idx].x, landmarks[right_eye_idx].y]
        xy_norm = np.linalg.solve(
            np.array([left_eye_v, right_eye_v]).T,
            np.array([[lmk.x, lmk.y] for lmk in landmarks]).T,
        ).T
        for i, lmk in enumerate(landmarks):
            lmk.x = xy_norm[i][0]
            lmk.y = xy_norm[i][1]
        return landmarks

    def extract_landmarks(self, img: np.ndarray):
        # MediaPipe expects RGB input; serialise FaceMesh calls via the lock.
        with self.lock:
            results = self.face_mesh.process(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        if results.multi_face_landmarks is None:
            return None
        return self.__normalise_landmarks(results.multi_face_landmarks[0].landmark)

    def extract_landmarks_flat(self, img: np.ndarray):
        # Flatten the landmark list into {"<i>_x": ..., "<i>_y": ..., "<i>_z": ...}
        # keys, one triple per landmark index.
        landmarks = self.extract_landmarks(img)
        if landmarks is None:
            return None
        flat_landmarks = {}
        for i, landmark in enumerate(landmarks):
            flat_landmarks.update(
                {
                    f"{i}_x": landmark.x,
                    f"{i}_y": landmark.y,
                    f"{i}_z": landmark.z,
                }
            )
        return flat_landmarks
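

# Usage sketch (illustrative only): a minimal example of how the pieces above
# fit together. The repo id, filename, environment variable, and image paths
# below are placeholders, not values taken from this project.
if __name__ == "__main__":
    import os

    model = Model(
        model_repo_id="your-user/your-model-repo",  # placeholder repo id
        model_filename="model.h5",                  # placeholder filename
        hf_token=os.environ.get("HF_TOKEN"),        # token read from the environment
    )
    with open("face_a.jpg", "rb") as f:             # placeholder image paths
        emb_a = model.predict(model.preprocess(f.read()))
    with open("face_b.jpg", "rb") as f:
        emb_b = model.predict(model.preprocess(f.read()))
    # A smaller distance means the two faces are closer in the embedding space.
    print(Model.distance(emb_a[0], emb_b[0]))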