Spaces:
Build error
Build error
natexcvi
committed on
Commit
•
607801a
1
Parent(s):
cad80c7
Add support for face extraction
Browse files
- model/fecnet.py +43 -2
- routers/fecnet_router.py +5 -3
model/fecnet.py
CHANGED
@@ -5,6 +5,7 @@ from importlib.util import module_from_spec, spec_from_file_location
|
|
5 |
from tempfile import TemporaryDirectory
|
6 |
|
7 |
import cv2
|
|
|
8 |
import numpy as np
|
9 |
import plotly.express as px
|
10 |
import requests
|
@@ -40,6 +41,10 @@ class FECNetModel:
|
|
40 |
self.__download_weights(repo_dir.name), fecnet_module.FECNet
|
41 |
)
|
42 |
|
|
|
|
|
|
|
|
|
43 |
def __download_weights(self, model_dir: str) -> str:
|
44 |
model_path = hf_hub_download(
|
45 |
"natexcvi/pretrained-fecnet",
|
@@ -62,12 +67,48 @@ class FECNetModel:
|
|
62 |
def distance(a, b):
    """Return the norm of the element-wise difference between ``a`` and ``b``.

    Uses ``np.linalg.norm`` with its default order: the 2-norm for 1-D
    input and the Frobenius norm for 2-D input.
    """
    difference = a - b
    return np.linalg.norm(difference)
|
64 |
|
65 |
-
def embed_image(self, image) -> np.ndarray:
    """Decode an encoded image buffer and return the model's embedding.

    Args:
        image: Encoded image bytes as a NumPy array, as accepted by
            ``cv2.imdecode``.

    Returns:
        The output of ``self.predict`` for the preprocessed image,
        detached and converted to a NumPy array.
    """
    decoded = cv2.imdecode(image, cv2.IMREAD_COLOR)
    resized = cv2.resize(decoded, (224, 224))
    # Move the last (channel) axis to the front, then add a leading
    # batch axis of size one.
    batch = np.expand_dims(np.transpose(resized, (2, 0, 1)), axis=0)
    tensor = torch.from_numpy(batch.astype(np.float32))
    return self.predict(tensor).detach().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
from tempfile import TemporaryDirectory
|
6 |
|
7 |
import cv2
|
8 |
+
import mediapipe as mp
|
9 |
import numpy as np
|
10 |
import plotly.express as px
|
11 |
import requests
|
|
|
41 |
self.__download_weights(repo_dir.name), fecnet_module.FECNet
|
42 |
)
|
43 |
|
44 |
+
self.face_detector = mp.solutions.face_detection.FaceDetection(
|
45 |
+
min_detection_confidence=0.5
|
46 |
+
)
|
47 |
+
|
48 |
def __download_weights(self, model_dir: str) -> str:
|
49 |
model_path = hf_hub_download(
|
50 |
"natexcvi/pretrained-fecnet",
|
|
|
67 |
def distance(a, b):
    """Return the norm of the element-wise difference between ``a`` and ``b``.

    Uses ``np.linalg.norm`` with its default order: the 2-norm for 1-D
    input and the Frobenius norm for 2-D input.
    """
    difference = a - b
    return np.linalg.norm(difference)
|
69 |
|
70 |
+
def embed_image(self, image, crop_face: bool = False) -> np.ndarray:
    """Decode an encoded image buffer and return the model's embedding.

    Args:
        image: Encoded image bytes as a NumPy array, as accepted by
            ``cv2.imdecode``.
        crop_face: When true, crop the decoded image with
            ``extract_face`` before embedding it. If no usable crop is
            produced, the full decoded image is embedded instead.

    Returns:
        The output of ``self.predict`` for the preprocessed image,
        detached and converted to a NumPy array.

    Raises:
        ValueError: If ``image`` cannot be decoded.
    """
    decoded = cv2.imdecode(image, cv2.IMREAD_COLOR)
    if decoded is None:
        # cv2.imdecode reports bad input by returning None instead of
        # raising; fail here with a clear message rather than with an
        # opaque assertion inside cv2.resize.
        raise ValueError("Could not decode image data")

    if crop_face:
        face = self.extract_face(decoded)
        # Keep the full frame when face extraction yields nothing usable.
        if face is not None and face.size > 0:
            decoded = face

    # NOTE(review): the image stays in OpenCV's BGR channel order here;
    # confirm that this is what the pretrained weights expect.
    resized = cv2.resize(decoded, (224, 224))
    # Move the last (channel) axis to the front, then add a leading
    # batch axis of size one.
    batch = np.expand_dims(np.transpose(resized, (2, 0, 1)), axis=0)
    tensor = torch.from_numpy(batch.astype(np.float32))
    pred = self.predict(tensor)
    return pred.detach().numpy()
|
81 |
+
|
82 |
+
def extract_face(self, image):
    """Crop a BGR image to the first face reported by the face detector.

    Args:
        image: Image in OpenCV's BGR channel order, indexed as
            (row, column, channel).

    Returns:
        The face region in BGR order. When no face is detected, or the
        detected box does not overlap the frame, the input image is
        returned unchanged so the caller always receives a usable image.
    """
    frame_height, frame_width = image.shape[:2]

    # The detector is run on an RGB copy; the crop itself is taken from
    # the original BGR image, so no conversion back is needed.
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # The detector is created once in the constructor as
    # ``self.face_detector``.
    results = self.face_detector.process(rgb_image)
    if not results.detections:
        return image

    # Only the first detection is used.
    box = results.detections[0].location_data.relative_bounding_box

    # The box is expressed relative to the frame size; scale to pixels.
    left = int(box.xmin * frame_width)
    top = int(box.ymin * frame_height)
    right = left + int(box.width * frame_width)
    bottom = top + int(box.height * frame_height)

    # Clamp to the frame: a box that extends past an edge would otherwise
    # produce negative indices, which NumPy interprets as wrap-around.
    left = min(max(left, 0), frame_width)
    right = min(max(right, 0), frame_width)
    top = min(max(top, 0), frame_height)
    bottom = min(max(bottom, 0), frame_height)

    face = image[top:bottom, left:right]
    if face.size == 0:
        return image
    return face.copy()
|
routers/fecnet_router.py
CHANGED
@@ -23,9 +23,10 @@ model = FECNetModel(os.getenv("HF_TOKEN", ""))
|
|
23 |
)
|
24 |
async def calculate_embedding(
    image: UploadFile = File(...),
):
    # Read the uploaded file, wrap its bytes in a uint8 array for the
    # model, and return the resulting embedding as a plain list.
    payload = await image.read()
    encoded = np.asarray(bytearray(payload), dtype=np.uint8)  # type: ignore
    embedding = model.embed_image(encoded)
    return EmbeddingResponse(embedding=embedding.tolist())
|
30 |
|
31 |
|
@@ -37,9 +38,10 @@ async def calculate_embedding(
|
|
37 |
async def calculate_similarity_score(
    image1: UploadFile = File(...),
    image2: UploadFile = File(...),
):
    # Embed both uploaded images and report the distance between the two
    # embeddings as the score (smaller means more similar).
    image1_arr = np.asarray(bytearray(await image1.read()), dtype=np.uint8)  # type: ignore
    image2_arr = np.asarray(bytearray(await image2.read()), dtype=np.uint8)  # type: ignore
    rep1 = model.embed_image(image1_arr)
    rep2 = model.embed_image(image2_arr)
    # The second positional argument of np.linalg.norm is ``ord``, not a
    # second vector, so the distance must be taken as the norm of the
    # difference. float() yields a plain Python number for the response.
    score = float(np.linalg.norm(rep1 - rep2))
    return SimilarityResponse(score=score)
|
|
|
23 |
)
|
24 |
async def calculate_embedding(
    image: UploadFile = File(...),
    should_extract_face: bool = False,
):
    # Read the uploaded file, wrap its bytes in a uint8 array for the
    # model, and return the resulting embedding as a plain list.
    # ``should_extract_face`` is forwarded to the model's embed_image call.
    payload = await image.read()
    encoded = np.asarray(bytearray(payload), dtype=np.uint8)  # type: ignore
    embedding = model.embed_image(encoded, should_extract_face)
    return EmbeddingResponse(embedding=embedding.tolist())
|
31 |
|
32 |
|
|
|
38 |
async def calculate_similarity_score(
    image1: UploadFile = File(...),
    image2: UploadFile = File(...),
    should_extract_face: bool = False,
):
    # Embed both uploaded images and report the distance between the two
    # embeddings as the score (smaller means more similar).
    # ``should_extract_face`` is forwarded to both embed_image calls.
    image1_arr = np.asarray(bytearray(await image1.read()), dtype=np.uint8)  # type: ignore
    image2_arr = np.asarray(bytearray(await image2.read()), dtype=np.uint8)  # type: ignore
    rep1 = model.embed_image(image1_arr, should_extract_face)
    rep2 = model.embed_image(image2_arr, should_extract_face)
    # The second positional argument of np.linalg.norm is ``ord``, not a
    # second vector, so the distance must be taken as the norm of the
    # difference. float() yields a plain Python number for the response.
    score = float(np.linalg.norm(rep1 - rep2))
    return SimilarityResponse(score=score)
|