import os
import sys
from importlib import invalidate_caches
from importlib.util import module_from_spec, spec_from_file_location
from tempfile import TemporaryDirectory

import cv2
import mediapipe as mp
import numpy as np
import torch
from git import Repo
from huggingface_hub import hf_hub_download


class FECNetModel:
    def __init__(self, hf_token: str) -> None:
        self.hf_token = hf_token
        # Keep a reference on self so the temporary clone is not cleaned up
        # while the model is still alive.
        self.repo_dir = TemporaryDirectory()
        Repo.clone_from(
            "https://github.com/AmirSh15/FECNet.git",
            self.repo_dir.name,
        )
        invalidate_caches()
        sys.path.append(self.repo_dir.name)
        # Patch the upstream module in place so it loads on CPU-only machines.
        fecnet_module_path = os.path.join(self.repo_dir.name, "models", "FECNet.py")
        with open(fecnet_module_path, "r") as f:
            content = f.read()
        content = content.replace("cuda", "cpu")
        with open(fecnet_module_path, "w") as f:
            f.write(content)
        # Import the patched module directly from the cloned file.
        spec = spec_from_file_location("FECNet", fecnet_module_path)
        fecnet_module = module_from_spec(spec)  # type: ignore
        spec.loader.exec_module(fecnet_module)  # type: ignore
        self.model = self.__load_model(
            self.__download_weights(), fecnet_module.FECNet
        )
        self.face_detector = mp.solutions.face_detection.FaceDetection(
            min_detection_confidence=0.5
        )

    def __download_weights(self) -> str:
        # Fetch the pretrained weights from the Hugging Face Hub.
        model_path = hf_hub_download(
            "natexcvi/pretrained-fecnet",
            "fecnet.pt",
            token=self.hf_token,
        )
        return model_path

    def __load_model(self, model_path: str, model_class):
        model = model_class(pretrained=False)
        model_weights = torch.load(model_path, map_location=torch.device("cpu"))
        model.load_state_dict(model_weights)
        model.eval()
        return model

    def predict(self, image: torch.Tensor):
        pred = self.model(image)
        return pred

    @staticmethod
    def distance(a, b):
        # Euclidean distance between two embeddings.
        return np.linalg.norm(a - b)

    def embed_image(self, image: np.ndarray, crop_face: bool = False) -> np.ndarray:
        # `image` is an encoded image buffer (raw JPEG/PNG bytes as a 1-D
        # uint8 array), which cv2.imdecode turns into a BGR pixel array.
        image = cv2.imdecode(image, cv2.IMREAD_COLOR)
        if crop_face:
            image = self.extract_face(image)
        # Resize to the 224x224 input FECNet expects, then reorder HWC -> NCHW
        # and convert to a float32 tensor.
        image = cv2.resize(image, (224, 224))
        image = np.transpose(image, (2, 0, 1))
        image = np.expand_dims(image, axis=0)
        image = torch.from_numpy(image.astype(np.float32))
        pred = self.predict(image)
        return pred.detach().numpy()

    def extract_face(self, image: np.ndarray) -> np.ndarray:
        # MediaPipe expects RGB input, while OpenCV decodes to BGR.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Run the face detection model on the image.
        results = self.face_detector.process(image)
        # Fail loudly instead of returning None, which would crash later
        # in cv2.resize with a confusing error.
        if not results.detections:
            raise ValueError("No face detected in the image.")
        # Crop to the first detected face; the bounding box is given in
        # relative coordinates, so scale it by the image dimensions.
        box = results.detections[0].location_data.relative_bounding_box
        x = int(box.xmin * image.shape[1])
        y = int(box.ymin * image.shape[0])
        w = int(box.width * image.shape[1])
        h = int(box.height * image.shape[0])
        cropped_image = image[y : y + h, x : x + w]
        return cv2.cvtColor(cropped_image, cv2.COLOR_RGB2BGR)
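

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how FECNetModel might be driven end to end: embed two
# face images and compare their expression embeddings. The HF_TOKEN
# environment variable and the sample image filenames are assumptions;
# substitute your own Hugging Face token and images.
if __name__ == "__main__":
    token = os.environ["HF_TOKEN"]  # assumed env var holding an HF access token
    model = FECNetModel(hf_token=token)

    # embed_image expects the raw encoded bytes of an image as a 1-D uint8
    # buffer (the form cv2.imdecode consumes), not a decoded pixel array.
    with open("face_a.jpg", "rb") as f:  # hypothetical sample image
        emb_a = model.embed_image(np.frombuffer(f.read(), dtype=np.uint8), crop_face=True)
    with open("face_b.jpg", "rb") as f:  # hypothetical sample image
        emb_b = model.embed_image(np.frombuffer(f.read(), dtype=np.uint8), crop_face=True)

    # Smaller distances indicate more similar facial expressions.
    print("expression distance:", FECNetModel.distance(emb_a, emb_b))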