natexcvi committed on
Commit
607801a
1 Parent(s): cad80c7

Add support for face extraction

Files changed (2)
  1. model/fecnet.py +43 -2
  2. routers/fecnet_router.py +5 -3
model/fecnet.py CHANGED
@@ -5,6 +5,7 @@ from importlib.util import module_from_spec, spec_from_file_location
 from tempfile import TemporaryDirectory
 
 import cv2
+import mediapipe as mp
 import numpy as np
 import plotly.express as px
 import requests
@@ -40,6 +41,10 @@ class FECNetModel:
             self.__download_weights(repo_dir.name), fecnet_module.FECNet
         )
 
+        self.face_detector = mp.solutions.face_detection.FaceDetection(
+            min_detection_confidence=0.5
+        )
+
     def __download_weights(self, model_dir: str) -> str:
         model_path = hf_hub_download(
             "natexcvi/pretrained-fecnet",
@@ -62,12 +67,48 @@ class FECNetModel:
     def distance(a, b):
         return np.linalg.norm(a - b)
 
-    def embed_image(self, image) -> np.ndarray:
+    def embed_image(self, image, crop_face: bool = False) -> np.ndarray:
         image = cv2.imdecode(image, cv2.IMREAD_COLOR)
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        if crop_face:
+            image = self.extract_face(image)
+        # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         image = cv2.resize(image, (224, 224))
         image = np.transpose(image, (2, 0, 1))
         image = np.expand_dims(image, axis=0)
         image = torch.from_numpy(image.astype(np.float32))
         pred = self.predict(image)
         return pred.detach().numpy()
+
+    def extract_face(self, image):
+        # Convert the image to RGB, as MediaPipe expects RGB input
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+        # Run the face detection model (initialised in __init__) on the image
+        results = self.face_detector.process(image)
+
+        # Fall back to the full frame if no face is detected
+        cropped_image = image
+        # If a face is detected, crop the image to the face box
+        if results.detections:
+            for detection in results.detections:
+                x, y, w, h = (
+                    int(
+                        detection.location_data.relative_bounding_box.xmin
+                        * image.shape[1]
+                    ),
+                    int(
+                        detection.location_data.relative_bounding_box.ymin
+                        * image.shape[0]
+                    ),
+                    int(
+                        detection.location_data.relative_bounding_box.width
+                        * image.shape[1]
+                    ),
+                    int(
+                        detection.location_data.relative_bounding_box.height
+                        * image.shape[0]
+                    ),
+                )
+                cropped_image = image[y : y + h, x : x + w]
+        return cv2.cvtColor(cropped_image, cv2.COLOR_RGB2BGR)
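
For reference, here is a minimal standalone sketch of the MediaPipe face-cropping technique that extract_face relies on, outside the FECNetModel class. The file paths and the full-frame fallback are illustrative assumptions, not part of this commit:

# Standalone sketch of MediaPipe-based face cropping (illustrative only).
# "input.jpg" and "face.jpg" are hypothetical paths used for the example.
import cv2
import mediapipe as mp

detector = mp.solutions.face_detection.FaceDetection(min_detection_confidence=0.5)

bgr = cv2.imread("input.jpg")               # OpenCV decodes images as BGR
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)  # MediaPipe expects RGB input
results = detector.process(rgb)

crop = bgr  # assumed fallback: keep the full frame when no face is found
if results.detections:
    # Relative bounding box coordinates are scaled by the image dimensions
    box = results.detections[0].location_data.relative_bounding_box
    img_h, img_w = bgr.shape[:2]
    x, y = int(box.xmin * img_w), int(box.ymin * img_h)
    w, h = int(box.width * img_w), int(box.height * img_h)
    crop = bgr[y : y + h, x : x + w]

cv2.imwrite("face.jpg", crop)
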
routers/fecnet_router.py CHANGED
@@ -23,9 +23,10 @@ model = FECNetModel(os.getenv("HF_TOKEN", ""))
 )
 async def calculate_embedding(
     image: UploadFile = File(...),
+    should_extract_face: bool = False,
 ):
     image_arr = np.asarray(bytearray(await image.read()), dtype=np.uint8)  # type: ignore
-    rep = model.embed_image(image_arr)
+    rep = model.embed_image(image_arr, should_extract_face)
     return EmbeddingResponse(embedding=rep.tolist())
 
 
@@ -37,9 +38,10 @@ async def calculate_embedding(
 async def calculate_similarity_score(
     image1: UploadFile = File(...),
     image2: UploadFile = File(...),
+    should_extract_face: bool = False,
 ):
     image1_arr = np.asarray(bytearray(await image1.read()), dtype=np.uint8)  # type: ignore
     image2_arr = np.asarray(bytearray(await image2.read()), dtype=np.uint8)  # type: ignore
-    rep1 = model.embed_image(image1_arr)
-    rep2 = model.embed_image(image2_arr)
+    rep1 = model.embed_image(image1_arr, should_extract_face)
+    rep2 = model.embed_image(image2_arr, should_extract_face)
     return SimilarityResponse(score=np.linalg.norm(rep1, rep2))
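
With the router changes above, face extraction can now be toggled per request through the new should_extract_face query parameter. A hypothetical client call is sketched below; the base URL and route path are assumptions, since the route decorators are not shown in this diff:

# Hypothetical client call to the updated embedding endpoint.
# Only the should_extract_face parameter and the "image" file field
# come from this commit; the URL and route path are assumed.
import requests

with open("face.jpg", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/embedding",  # assumed route
        params={"should_extract_face": True},
        files={"image": f},
    )
resp.raise_for_status()
print(resp.json())  # expected to carry the "embedding" field of EmbeddingResponse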