paresh95 committed on
Commit
008760f
1 Parent(s): 9bff0ef

PS | Change age and gender models to VIT

Browse files
data/4_6_boy.jpg ADDED
notebooks/facial_age_gender.ipynb CHANGED
@@ -22,7 +22,7 @@
22
  {
23
  "data": {
24
  "text/plain": [
25
- "'/Users/pareshar/Personal/Github/Facial-feature-detector'"
26
  ]
27
  },
28
  "execution_count": 2,
@@ -308,6 +308,206 @@
308
  "df.sort_values(\"file_name\")"
309
  ]
310
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  {
312
  "cell_type": "markdown",
313
  "metadata": {},
@@ -315,6 +515,11 @@
315
  "# Other\n",
316
  "- Dataset used to train model: https://talhassner.github.io/home/projects/Adience/Adience-data.html#agegender"
317
  ]
 
 
 
 
 
318
  }
319
  ],
320
  "metadata": {
 
22
  {
23
  "data": {
24
  "text/plain": [
25
+ "'/Users/pareshar/Personal/Github/temp/Facial-feature-detector'"
26
  ]
27
  },
28
  "execution_count": 2,
 
308
  "df.sort_values(\"file_name\")"
309
  ]
310
  },
311
+ {
312
+ "cell_type": "markdown",
313
+ "metadata": {},
314
+ "source": [
315
+ "# Hugging face pre-trained VIT model"
316
+ ]
317
+ },
318
+ {
319
+ "cell_type": "code",
320
+ "execution_count": 5,
321
+ "metadata": {},
322
+ "outputs": [
323
+ {
324
+ "name": "stderr",
325
+ "output_type": "stream",
326
+ "text": [
327
+ "/Users/pareshar/.pyenv/versions/3.8.10/lib/python3.8/site-packages/requests/__init__.py:102: RequestsDependencyWarning: urllib3 (1.26.15) or chardet (5.1.0)/charset_normalizer (2.0.12) doesn't match a supported version!\n",
328
+ " warnings.warn(\"urllib3 ({}) or chardet ({})/charset_normalizer ({}) doesn't match a supported \"\n",
329
+ "/Users/pareshar/.pyenv/versions/3.8.10/lib/python3.8/site-packages/urllib3/connectionpool.py:1045: InsecureRequestWarning: Unverified HTTPS request is being made to host 'huggingface.co'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/1.26.x/advanced-usage.html#ssl-warnings\n",
330
+ " warnings.warn(\n",
331
+ "/Users/pareshar/.pyenv/versions/3.8.10/lib/python3.8/site-packages/urllib3/connectionpool.py:1045: InsecureRequestWarning: Unverified HTTPS request is being made to host 'huggingface.co'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/1.26.x/advanced-usage.html#ssl-warnings\n",
332
+ " warnings.warn(\n",
333
+ "/Users/pareshar/.pyenv/versions/3.8.10/lib/python3.8/site-packages/transformers/models/vit/feature_extraction_vit.py:28: FutureWarning: The class ViTFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use ViTImageProcessor instead.\n",
334
+ " warnings.warn(\n"
335
+ ]
336
+ }
337
+ ],
338
+ "source": [
339
+ "# age\n",
340
+ "\n",
341
+ "import os\n",
342
+ "import cv2\n",
343
+ "from transformers import ViTImageProcessor, ViTForImageClassification\n",
344
+ "\n",
345
+ "os.environ[\n",
346
+ " \"CURL_CA_BUNDLE\"\n",
347
+ " ] = \"\" # fixes VPN issue when connecting to hugging face hub\n",
348
+ "\n",
349
+ "\n",
350
+ "image = cv2.imread(\"data/4_6_boy.jpg\")\n",
351
+ "\n",
352
+ "\n",
353
+ "# Init model, transforms\n",
354
+ "model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier')\n",
355
+ "transforms = ViTImageProcessor.from_pretrained('nateraw/vit-age-classifier')\n",
356
+ "\n",
357
+ "# Transform our image and pass it through the model\n",
358
+ "inputs = transforms(image, return_tensors='pt')\n",
359
+ "output = model(**inputs)\n",
360
+ "\n",
361
+ "# Predicted Class probabilities\n",
362
+ "proba = output.logits.softmax(1)\n",
363
+ "\n",
364
+ "# Predicted Classes\n",
365
+ "preds = proba.argmax(1)\n"
366
+ ]
367
+ },
368
+ {
369
+ "cell_type": "code",
370
+ "execution_count": 24,
371
+ "metadata": {},
372
+ "outputs": [
373
+ {
374
+ "data": {
375
+ "text/plain": [
376
+ "0.7176125645637512"
377
+ ]
378
+ },
379
+ "execution_count": 24,
380
+ "metadata": {},
381
+ "output_type": "execute_result"
382
+ }
383
+ ],
384
+ "source": [
385
+ "max(proba[0]).item()"
386
+ ]
387
+ },
388
+ {
389
+ "cell_type": "code",
390
+ "execution_count": 7,
391
+ "metadata": {},
392
+ "outputs": [
393
+ {
394
+ "data": {
395
+ "text/plain": [
396
+ "'3-9'"
397
+ ]
398
+ },
399
+ "execution_count": 7,
400
+ "metadata": {},
401
+ "output_type": "execute_result"
402
+ }
403
+ ],
404
+ "source": [
405
+ "id2label = {\n",
406
+ " 0: \"0-2\",\n",
407
+ " 1: \"3-9\",\n",
408
+ " 2: \"10-19\",\n",
409
+ " 3: \"20-29\",\n",
410
+ " 4: \"30-39\",\n",
411
+ " 5: \"40-49\",\n",
412
+ " 6: \"50-59\",\n",
413
+ " 7: \"60-69\",\n",
414
+ " 8: \"more than 70\"\n",
415
+ " }\n",
416
+ "\n",
417
+ "id2label[int(preds)]"
418
+ ]
419
+ },
420
+ {
421
+ "cell_type": "code",
422
+ "execution_count": 28,
423
+ "metadata": {},
424
+ "outputs": [
425
+ {
426
+ "name": "stderr",
427
+ "output_type": "stream",
428
+ "text": [
429
+ "/Users/pareshar/.pyenv/versions/3.8.10/lib/python3.8/site-packages/urllib3/connectionpool.py:1045: InsecureRequestWarning: Unverified HTTPS request is being made to host 'huggingface.co'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/1.26.x/advanced-usage.html#ssl-warnings\n",
430
+ " warnings.warn(\n",
431
+ "/Users/pareshar/.pyenv/versions/3.8.10/lib/python3.8/site-packages/urllib3/connectionpool.py:1045: InsecureRequestWarning: Unverified HTTPS request is being made to host 'huggingface.co'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/1.26.x/advanced-usage.html#ssl-warnings\n",
432
+ " warnings.warn(\n"
433
+ ]
434
+ }
435
+ ],
436
+ "source": [
437
+ "# gender\n",
438
+ "\n",
439
+ "import os\n",
440
+ "import cv2\n",
441
+ "from transformers import ViTImageProcessor, ViTForImageClassification\n",
442
+ "\n",
443
+ "os.environ[\n",
444
+ " \"CURL_CA_BUNDLE\"\n",
445
+ " ] = \"\" # fixes VPN issue when connecting to hugging face hub\n",
446
+ "\n",
447
+ "\n",
448
+ "image = cv2.imread(\"data/gigi_hadid.webp\")\n",
449
+ "\n",
450
+ "\n",
451
+ "# Init model, transforms\n",
452
+ "model = ViTForImageClassification.from_pretrained('rizvandwiki/gender-classification')\n",
453
+ "transforms = ViTImageProcessor.from_pretrained('rizvandwiki/gender-classification')\n",
454
+ "\n",
455
+ "# Transform our image and pass it through the model\n",
456
+ "inputs = transforms(image, return_tensors='pt')\n",
457
+ "output = model(**inputs)\n",
458
+ "\n",
459
+ "# Predicted Class probabilities\n",
460
+ "proba = output.logits.softmax(1)\n",
461
+ "\n",
462
+ "# Predicted Classes\n",
463
+ "preds = proba.argmax(1)\n"
464
+ ]
465
+ },
466
+ {
467
+ "cell_type": "code",
468
+ "execution_count": 29,
469
+ "metadata": {},
470
+ "outputs": [
471
+ {
472
+ "data": {
473
+ "text/plain": [
474
+ "0.9677436351776123"
475
+ ]
476
+ },
477
+ "execution_count": 29,
478
+ "metadata": {},
479
+ "output_type": "execute_result"
480
+ }
481
+ ],
482
+ "source": [
483
+ "max(proba[0]).item()"
484
+ ]
485
+ },
486
+ {
487
+ "cell_type": "code",
488
+ "execution_count": 30,
489
+ "metadata": {},
490
+ "outputs": [
491
+ {
492
+ "data": {
493
+ "text/plain": [
494
+ "'female'"
495
+ ]
496
+ },
497
+ "execution_count": 30,
498
+ "metadata": {},
499
+ "output_type": "execute_result"
500
+ }
501
+ ],
502
+ "source": [
503
+ "id2label = {\n",
504
+ " 0: \"female\",\n",
505
+ " 1: \"male\",\n",
506
+ " }\n",
507
+ "\n",
508
+ "id2label[int(preds)]"
509
+ ]
510
+ },
511
  {
512
  "cell_type": "markdown",
513
  "metadata": {},
 
515
  "# Other\n",
516
  "- Dataset used to train model: https://talhassner.github.io/home/projects/Adience/Adience-data.html#agegender"
517
  ]
518
+ },
519
+ {
520
+ "cell_type": "markdown",
521
+ "metadata": {},
522
+ "source": []
523
  }
524
  ],
525
  "metadata": {
requirements.txt CHANGED
@@ -6,3 +6,4 @@ imutils==0.5.4
6
  pillow==9.4.0
7
  pyyaml==6.0
8
  scikit-learn==1.2.2
 
 
6
  pillow==9.4.0
7
  pyyaml==6.0
8
  scikit-learn==1.2.2
9
+ transformers==4.28.1
src/face_demographics.py CHANGED
@@ -4,6 +4,8 @@ import numpy as np
4
  import os
5
  from typing import Tuple
6
  from src.cv_utils import get_image
 
 
7
 
8
 
9
  with open("parameters.yml", "r") as stream:
@@ -18,7 +20,13 @@ class GetFaceDemographics:
18
  pass
19
 
20
  @staticmethod
21
- def get_age(blob) -> Tuple:
 
 
 
 
 
 
22
  age_net = cv2.dnn.readNet(parameters["face_age"]["config"], parameters["face_age"]["model"])
23
  age_list = ['(0-2)', '(4-6)', '(8-12)', '(15-20)', '(25-32)', '(38-43)', '(48-53)', '(60-100)']
24
  age_net.setInput(blob)
@@ -29,7 +37,7 @@ class GetFaceDemographics:
29
  return age, age_confidence_score
30
 
31
  @staticmethod
32
- def get_gender(blob) -> Tuple:
33
  gender_net = cv2.dnn.readNet(parameters["face_gender"]["config"], parameters["face_gender"]["model"])
34
  gender_list = ['Male', 'Female']
35
  gender_net.setInput(blob)
@@ -38,13 +46,54 @@ class GetFaceDemographics:
38
  gender = gender_list[i]
39
  gender_confidence_score = gender_preds[0][i]
40
  return gender, gender_confidence_score
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  def main(self, image_input) -> dict:
43
  image = get_image(image_input)
44
- model_mean = (78.4263377603, 87.7689143744, 114.895847746) # taken from the model page on Caffe
45
- blob = cv2.dnn.blobFromImage(image, 1.0, (227, 227), model_mean, swapRB=False)
46
- age, age_confidence_score = self.get_age(blob)
47
- gender, gender_confidence_score = self.get_gender(blob)
48
  d = {
49
  "age_range": age,
50
  "age_confidence": age_confidence_score,
@@ -53,7 +102,6 @@ class GetFaceDemographics:
53
  }
54
  return d
55
 
56
-
57
  if __name__ == "__main__":
58
  path_to_images = "data/"
59
  image_files = os.listdir(path_to_images)
 
4
  import os
5
  from typing import Tuple
6
  from src.cv_utils import get_image
7
+ from transformers import ViTImageProcessor, ViTForImageClassification
8
+ import urllib3
9
 
10
 
11
  with open("parameters.yml", "r") as stream:
 
20
  pass
21
 
22
  @staticmethod
23
+ def preprocess_image_for_caffe_cnn(image: np.array):
24
+ model_mean = (78.4263377603, 87.7689143744, 114.895847746) # taken from the model page on Caffe
25
+ blob = cv2.dnn.blobFromImage(image, 1.0, (227, 227), model_mean, swapRB=False)
26
+ return blob
27
+
28
+ @staticmethod
29
+ def get_age_cnn(blob) -> Tuple:
30
  age_net = cv2.dnn.readNet(parameters["face_age"]["config"], parameters["face_age"]["model"])
31
  age_list = ['(0-2)', '(4-6)', '(8-12)', '(15-20)', '(25-32)', '(38-43)', '(48-53)', '(60-100)']
32
  age_net.setInput(blob)
 
37
  return age, age_confidence_score
38
 
39
  @staticmethod
40
+ def get_gender_cnn(blob) -> Tuple:
41
  gender_net = cv2.dnn.readNet(parameters["face_gender"]["config"], parameters["face_gender"]["model"])
42
  gender_list = ['Male', 'Female']
43
  gender_net.setInput(blob)
 
46
  gender = gender_list[i]
47
  gender_confidence_score = gender_preds[0][i]
48
  return gender, gender_confidence_score
49
+
50
+ @staticmethod
51
+ def get_age_vit(image: np.array) -> Tuple:
52
+ os.environ["CURL_CA_BUNDLE"] = "" # fixes VPN issue when connecting to hugging face hub
53
+ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
54
+ id2label = {
55
+ 0: "0-2",
56
+ 1: "3-9",
57
+ 2: "10-19",
58
+ 3: "20-29",
59
+ 4: "30-39",
60
+ 5: "40-49",
61
+ 6: "50-59",
62
+ 7: "60-69",
63
+ 8: "more than 70"
64
+ }
65
+ model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier')
66
+ transforms = ViTImageProcessor.from_pretrained('nateraw/vit-age-classifier')
67
+ inputs = transforms(image, return_tensors='pt')
68
+ output = model(**inputs)
69
+ proba = output.logits.softmax(1)
70
+ preds = proba.argmax(1)
71
+ age_confidence_score = max(proba[0]).item()
72
+ age = id2label[int(preds)]
73
+ return age, age_confidence_score
74
+
75
+ @staticmethod
76
+ def get_gender_vit(image: np.array) -> Tuple:
77
+ os.environ["CURL_CA_BUNDLE"] = "" # fixes VPN issue when connecting to hugging face hub
78
+ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
79
+ id2label = {
80
+ 0: "female",
81
+ 1: "male",
82
+ }
83
+ model = ViTForImageClassification.from_pretrained('rizvandwiki/gender-classification')
84
+ transforms = ViTImageProcessor.from_pretrained('rizvandwiki/gender-classification')
85
+ inputs = transforms(image, return_tensors='pt')
86
+ output = model(**inputs)
87
+ proba = output.logits.softmax(1)
88
+ preds = proba.argmax(1)
89
+ gender_confidence_score = max(proba[0]).item()
90
+ gender = id2label[int(preds)]
91
+ return gender, gender_confidence_score
92
 
93
  def main(self, image_input) -> dict:
94
  image = get_image(image_input)
95
+ age, age_confidence_score = self.get_age_vit(image)
96
+ gender, gender_confidence_score = self.get_gender_vit(image)
 
 
97
  d = {
98
  "age_range": age,
99
  "age_confidence": age_confidence_score,
 
102
  }
103
  return d
104
 
 
105
  if __name__ == "__main__":
106
  path_to_images = "data/"
107
  image_files = os.listdir(path_to_images)