File size: 3,800 Bytes
db86bfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from functools import lru_cache
from typing import List, Tuple

import numpy

from facefusion import inference_manager
from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url
from facefusion.face_helper import warp_face_by_face_landmark_5
from facefusion.filesystem import resolve_relative_path
from facefusion.thread_helper import conditional_thread_semaphore
from facefusion.typing import Age, DownloadScope, FaceLandmark5, Gender, InferencePool, ModelOptions, ModelSet, Race, VisionFrame


@lru_cache(maxsize = None)
def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
	"""
	Build the static model registry for the face classifier.

	The result is memoised per download_scope by lru_cache. download_scope is not read
	by this function; both the hash file and the ONNX source are resolved against the
	'models-3.0.0' release and stored under '../.assets/models/'.
	"""
	hash_entry =\
	{
		'url': resolve_download_url('models-3.0.0', 'fairface.hash'),
		'path': resolve_relative_path('../.assets/models/fairface.hash')
	}
	source_entry =\
	{
		'url': resolve_download_url('models-3.0.0', 'fairface.onnx'),
		'path': resolve_relative_path('../.assets/models/fairface.onnx')
	}
	fairface_options =\
	{
		'hashes':
		{
			'face_classifier': hash_entry
		},
		'sources':
		{
			'face_classifier': source_entry
		},
		# alignment template and input resolution used when cropping the face
		'template': 'arcface_112_v2',
		'size': (224, 224),
		# per-channel normalisation constants applied after scaling by 1/255
		'mean': [ 0.485, 0.456, 0.406 ],
		'standard_deviation': [ 0.229, 0.224, 0.225 ]
	}
	return\
	{
		'fairface': fairface_options
	}


def get_inference_pool() -> InferencePool:
	"""Fetch this module's inference pool from inference_manager, passing the configured model sources."""
	model_options = get_model_options()
	return inference_manager.get_inference_pool(__name__, model_options.get('sources'))


def clear_inference_pool() -> None:
	"""Ask inference_manager to clear the inference pool keyed by this module's name."""
	pool_owner = __name__
	inference_manager.clear_inference_pool(pool_owner)


def get_model_options() -> ModelOptions:
	"""Return the 'fairface' options from the static model set built for the 'full' scope."""
	static_model_set = create_static_model_set('full')
	return static_model_set.get('fairface')


def pre_check() -> bool:
	"""
	Run the conditional downloads for the classifier's hash files and model sources.

	The sources step is only attempted when the hashes step returns a truthy value;
	otherwise that value is returned as-is.
	"""
	model_options = get_model_options()
	model_hashes = model_options.get('hashes')
	model_sources = model_options.get('sources')
	hashes_ready = conditional_download_hashes(model_hashes)
	return hashes_ready and conditional_download_sources(model_sources)


def classify_face(temp_vision_frame : VisionFrame, face_landmark_5 : FaceLandmark5) -> Tuple[Gender, Age, Race]:
	"""
	Predict (gender, age range, race) for the face located by face_landmark_5.

	The face is warped to the model template/size and then preprocessed before being passed to forward():
	cast to float32, last axis reversed (presumably BGR -> RGB; confirm), divided by 255,
	normalised in place with the model mean and standard deviation, moved to channel-first
	layout (assumes HWC input — confirm), and given a leading batch axis.
	"""
	model_options = get_model_options()
	model_mean = model_options.get('mean')
	model_standard_deviation = model_options.get('standard_deviation')
	crop_vision_frame, _ = warp_face_by_face_landmark_5(temp_vision_frame, face_landmark_5, model_options.get('template'), model_options.get('size'))
	crop_vision_frame = crop_vision_frame.astype(numpy.float32)[:, :, ::-1] / 255
	# in-place ops keep the float32 dtype; a plain `frame - list` would promote to float64
	crop_vision_frame -= model_mean
	crop_vision_frame /= model_standard_deviation
	crop_vision_frame = numpy.expand_dims(crop_vision_frame.transpose(2, 0, 1), axis = 0)
	gender_id, age_id, race_id = forward(crop_vision_frame)
	return categorize_gender(gender_id[0]), categorize_age(age_id[0]), categorize_race(race_id[0])


def forward(crop_vision_frame : VisionFrame) -> Tuple[List[int], List[int], List[int]]:
	"""
	Run the 'face_classifier' session on a preprocessed crop.

	The session outputs are unpacked in the order (race, gender, age) and returned
	reordered as (gender, age, race). The run itself executes inside
	conditional_thread_semaphore().
	"""
	face_classifier = get_inference_pool().get('face_classifier')
	model_inputs =\
	{
		'input': crop_vision_frame
	}

	with conditional_thread_semaphore():
		race_id, gender_id, age_id = face_classifier.run(None, model_inputs)

	return gender_id, age_id, race_id


def categorize_gender(gender_id : int) -> Gender:
	"""Map a gender class id to a label: 1 becomes 'female', every other id becomes 'male'."""
	return 'female' if gender_id == 1 else 'male'


def categorize_age(age_id : int) -> Age:
	"""
	Map an age class id to an age range.

	Ids 0 through 7 select the matching entry of the table below; any other id yields
	range(70, 100).

	NOTE(review): range stops are exclusive, so range(0, 2) does not contain 2 and the
	ages 2, 9, 19, ..., 69 fall into no bucket. Callers may rely on .start/.stop as display
	labels — confirm before widening the stops.
	"""
	bounded_age_ranges =\
	[
		range(0, 2),
		range(3, 9),
		range(10, 19),
		range(20, 29),
		range(30, 39),
		range(40, 49),
		range(50, 59),
		range(60, 69)
	]
	# equality test (not indexing) so negative or non-int ids fall through to the default
	for candidate_id, age_range in enumerate(bounded_age_ranges):
		if age_id == candidate_id:
			return age_range
	return range(70, 100)


def categorize_race(race_id : int) -> Race:
	"""
	Map a race class id to a label.

	Ids 3 and 4 both map to 'asian'; 0 and any id not listed below map to 'white'.
	"""
	race_labels =\
	[
		((1,), 'black'),
		((2,), 'latino'),
		((3, 4), 'asian'),
		((5,), 'indian'),
		((6,), 'arabic')
	]
	for candidate_ids, race_label in race_labels:
		if any(race_id == candidate_id for candidate_id in candidate_ids):
			return race_label
	return 'white'