# natexcvi
# Add support for face extraction
# 607801a unverified
import os
import sys
from importlib import import_module, invalidate_caches
from importlib.util import module_from_spec, spec_from_file_location
from tempfile import TemporaryDirectory
import cv2
import mediapipe as mp
import numpy as np
import plotly.express as px
import requests
import torch
from git import Repo
from huggingface_hub import hf_hub_download
class FECNetModel:
    """FECNet facial-expression embedding model running on CPU.

    On construction, the FECNet repository is cloned into a temporary
    directory and its model definition is patched to use the CPU. Pretrained
    weights are downloaded from the Hugging Face Hub, and a MediaPipe face
    detector is created for optional face cropping.
    """

    def __init__(self, hf_token: str) -> None:
        """Clone FECNet, load the pretrained weights, and set up face detection.

        Args:
            hf_token: Hugging Face token passed to ``hf_hub_download`` when
                fetching ``natexcvi/pretrained-fecnet``/``fecnet.pt``.
        """
        self.hf_token = hf_token
        # Keep the checkout alive for the lifetime of the model.
        # TemporaryDirectory removes itself once garbage-collected, which would
        # leave the sys.path entry below pointing at a deleted directory.
        self._repo_dir = TemporaryDirectory()
        repo_path = self._repo_dir.name
        Repo.clone_from(
            "https://github.com/AmirSh15/FECNet.git",
            repo_path,
        )
        # Make the checkout importable. FECNet.py may import sibling modules.
        sys.path.append(repo_path)
        invalidate_caches()
        fecnet_module_path = os.path.join(repo_path, "models", "FECNet.py")
        # HACK: textually replace every "cuda" with "cpu" in the model source so
        # it can be built without a GPU (presumably upstream hard-codes CUDA
        # devices; confirm if the upstream file changes).
        with open(fecnet_module_path, "r", encoding="utf-8") as f:
            content = f.read()
        content = content.replace(
            "cuda",
            "cpu",
        )
        with open(fecnet_module_path, "w", encoding="utf-8") as f:
            f.write(content)
        spec = spec_from_file_location("FECNet", fecnet_module_path)
        fecnet_module = module_from_spec(spec)  # type: ignore
        spec.loader.exec_module(fecnet_module)  # type: ignore
        self.model = self.__load_model(
            self.__download_weights(repo_path), fecnet_module.FECNet
        )
        self.face_detector = mp.solutions.face_detection.FaceDetection(
            min_detection_confidence=0.5
        )

    def __download_weights(self, model_dir: str) -> str:
        """Download the pretrained FECNet weights and return their local path.

        Args:
            model_dir: Currently unused. The weights come from the Hugging Face
                Hub cache, not from the cloned repository.

        Returns:
            The local filesystem path returned by ``hf_hub_download``.
        """
        model_path = hf_hub_download(
            "natexcvi/pretrained-fecnet",
            "fecnet.pt",
            token=self.hf_token,
        )
        return model_path

    def __load_model(self, model_path: str, model_class):
        """Instantiate *model_class*, load the weights at *model_path*, and set eval mode.

        Args:
            model_path: Path to a saved ``state_dict``.
            model_class: The ``FECNet`` class from the cloned repository.

        Returns:
            The model, loaded on CPU and switched to evaluation mode.
        """
        model = model_class(pretrained=False)
        # NOTE(review): torch.load unpickles the file. Only load weights from a
        # trusted source (consider weights_only=True if the torch version allows).
        model_weights = torch.load(model_path, map_location=torch.device("cpu"))
        model.load_state_dict(model_weights)
        model.eval()
        return model

    def predict(self, image: np.ndarray):
        """Run the model forward on *image* and return the raw output.

        ``embed_image`` passes a float32 tensor of shape (1, 3, 224, 224).
        """
        pred = self.model(image)
        return pred

    @staticmethod
    def distance(a: np.ndarray, b: np.ndarray) -> float:
        """Return the Euclidean (L2) distance between embeddings *a* and *b*.

        This is a static method, so it works both as ``FECNetModel.distance(a, b)``
        and as ``model.distance(a, b)``.
        """
        return np.linalg.norm(a - b)

    def embed_image(self, image, crop_face: bool = False) -> np.ndarray:
        """Decode an encoded image and return its FECNet embedding.

        Args:
            image: Encoded image buffer accepted by ``cv2.imdecode``
                (presumably the raw bytes of a JPEG/PNG as a uint8 array;
                confirm against callers).
            crop_face: If true, crop to the first detected face before
                embedding. The full image is used when no face is found.

        Returns:
            The model output for a batch of one image, as a NumPy array.
        """
        image = cv2.imdecode(image, cv2.IMREAD_COLOR)
        if crop_face:
            image = self.extract_face(image)
        # NOTE: pixels are fed in OpenCV's BGR order. No RGB conversion is
        # applied; confirm this matches how the FECNet weights were trained.
        image = cv2.resize(image, (224, 224))
        # HWC -> CHW, then add a batch axis: (1, 3, 224, 224).
        image = np.transpose(image, (2, 0, 1))
        image = np.expand_dims(image, axis=0)
        tensor = torch.from_numpy(image.astype(np.float32))
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            pred = self.predict(tensor)
        return pred.detach().numpy()

    def extract_face(self, image: np.ndarray) -> np.ndarray:
        """Crop a BGR image to the first face found by the MediaPipe detector.

        Args:
            image: Decoded BGR image, as produced by ``cv2.imdecode``.

        Returns:
            The face crop as a new BGR array. Returns *image* unchanged when no
            face is detected or the detected box does not overlap the image.
        """
        # MediaPipe expects RGB input, while OpenCV decodes to BGR.
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        results = self.face_detector.process(rgb_image)
        if not results.detections:
            # No face: fall back to the whole image so callers can still embed it.
            return image
        height, width = rgb_image.shape[:2]
        box = results.detections[0].location_data.relative_bounding_box
        x = int(box.xmin * width)
        y = int(box.ymin * height)
        w = int(box.width * width)
        h = int(box.height * height)
        # Relative boxes can extend past the frame. Clamp them, because negative
        # slice starts would wrap around and yield an empty crop.
        x0, y0 = max(x, 0), max(y, 0)
        x1, y1 = min(x + w, width), min(y + h, height)
        if x1 <= x0 or y1 <= y0:
            return image
        # Cropping the BGR input directly matches cropping the RGB copy and
        # converting back. copy() returns an independent, contiguous array.
        return image[y0:y1, x0:x1].copy()