|
import tensorflow as tf |
|
from tensorflow.keras import Model |
|
from tensorflow.keras.applications import MobileNetV2, ResNet50 |
|
from tensorflow.keras.layers import Input, Conv2D, ReLU, LeakyReLU |
|
from retinaface.anchor import decode_tf, prior_box_tf |
|
|
|
|
|
def _regularizer(weights_decay):
    """Build an L2 weight-decay regularizer with the given coefficient."""
    l2 = tf.keras.regularizers.l2
    return l2(weights_decay)
|
|
|
|
|
def _kernel_init(scale=1.0, seed=None):
    """He-normal kernel initializer, optionally scaled and seeded.

    Args:
        scale: multiplier on the He-normal variance. ``scale=1.0`` reproduces
            plain ``he_normal`` (VarianceScaling with scale=2.0, fan_in,
            truncated normal), so the default behavior is unchanged.
        seed: optional seed for reproducible weight draws.

    Returns:
        A ``tf.keras.initializers.VarianceScaling`` instance.
    """
    # Bug fix: the previous implementation returned he_normal() and silently
    # ignored both `scale` and `seed`; both are now wired through.
    return tf.keras.initializers.VarianceScaling(
        scale=2.0 * scale, mode='fan_in', distribution='truncated_normal',
        seed=seed)
|
|
|
|
|
class BatchNormalization(tf.keras.layers.BatchNormalization):
    """Batch norm that honestly freezes when the layer is not trainable.

    Stock Keras BN can keep updating its moving statistics in some
    ``trainable=False`` configurations; this subclass forces inference-mode
    behavior whenever the layer itself is not trainable.
    ref: https://github.com/zzh8829/yolov3-tf2
    """
    def __init__(self, axis=-1, momentum=0.9, epsilon=1e-5, center=True,
                 scale=True, name=None, **kwargs):
        super(BatchNormalization, self).__init__(
            axis=axis, momentum=momentum, epsilon=epsilon, center=center,
            scale=scale, name=name, **kwargs)

    def call(self, x, training=False):
        # An unspecified training flag is treated as inference mode.
        flag = tf.constant(False) if training is None else training
        # Training-mode statistics only when the layer is actually trainable.
        flag = tf.logical_and(flag, self.trainable)
        return super().call(x, flag)
|
|
|
|
|
def Backbone(backbone_type='ResNet50', use_pretrain=True):
    """Return a callable that applies the chosen feature-extractor backbone.

    The returned closure preprocesses its input and yields three
    intermediate feature maps (picked by layer index) for the FPN.
    """
    weights = 'imagenet' if use_pretrain else None

    def backbone(x):
        if backbone_type == 'ResNet50':
            extractor = ResNet50(
                input_shape=x.shape[1:], include_top=False, weights=weights)
            picks = (80, 142, 174)  # C3/C4/C5-level outputs by layer index
            preprocess = tf.keras.applications.resnet.preprocess_input
        elif backbone_type == 'MobileNetV2':
            extractor = MobileNetV2(
                input_shape=x.shape[1:], include_top=False, weights=weights)
            picks = (54, 116, 143)
            preprocess = tf.keras.applications.mobilenet_v2.preprocess_input
        else:
            raise NotImplementedError(
                'Backbone type {} is not recognized.'.format(backbone_type))

        outputs = tuple(extractor.layers[idx].output for idx in picks)
        # NOTE: '_extrator' (sic) is kept verbatim — it is a persisted layer
        # name; renaming it would break checkpoint loading by name.
        feature_model = Model(extractor.input, outputs,
                              name=backbone_type + '_extrator')
        return feature_model(preprocess(x))

    return backbone
|
|
|
|
|
class ConvUnit(tf.keras.layers.Layer):
    """Conv + BN + Act

    Args:
        f: number of output filters.
        k: kernel size.
        s: stride.
        wd: L2 weight-decay coefficient for the conv kernel.
        act: activation name — None (identity), 'relu', or 'lrelu'.

    Raises:
        NotImplementedError: if `act` is not one of the recognized names.
    """
    def __init__(self, f, k, s, wd, act=None, **kwargs):
        super(ConvUnit, self).__init__(**kwargs)
        # Keep constructor args so get_config() can serialize this layer,
        # matching the convention of FPN/SSH/heads in this file.
        self.f = f
        self.k = k
        self.s = s
        self.wd = wd
        self.act = act
        self.conv = Conv2D(filters=f, kernel_size=k, strides=s, padding='same',
                           kernel_initializer=_kernel_init(),
                           kernel_regularizer=_regularizer(wd),
                           use_bias=False)  # bias is redundant before BN
        self.bn = BatchNormalization()

        if act is None:
            self.act_fn = tf.identity
        elif act == 'relu':
            self.act_fn = ReLU()
        elif act == 'lrelu':
            self.act_fn = LeakyReLU(0.1)
        else:
            raise NotImplementedError(
                'Activation function type {} is not recognized.'.format(act))

    def call(self, x):
        return self.act_fn(self.bn(self.conv(x)))

    def get_config(self):
        # Added for consistency with the sibling layers (FPN, SSH, heads):
        # without it, models containing ConvUnit cannot round-trip through
        # Keras serialization.
        config = {
            'f': self.f,
            'k': self.k,
            's': self.s,
            'wd': self.wd,
            'act': self.act,
        }
        base_config = super(ConvUnit, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
|
|
|
|
|
class FPN(tf.keras.layers.Layer):
    """Feature Pyramid Network over three backbone feature maps."""
    def __init__(self, out_ch, wd, **kwargs):
        super(FPN, self).__init__(**kwargs)
        self.out_ch = out_ch
        self.wd = wd
        # Narrow pyramids use leaky ReLU, wider ones plain ReLU.
        act = 'lrelu' if out_ch <= 64 else 'relu'

        self.output1 = ConvUnit(f=out_ch, k=1, s=1, wd=wd, act=act)
        self.output2 = ConvUnit(f=out_ch, k=1, s=1, wd=wd, act=act)
        self.output3 = ConvUnit(f=out_ch, k=1, s=1, wd=wd, act=act)
        self.merge1 = ConvUnit(f=out_ch, k=3, s=1, wd=wd, act=act)
        self.merge2 = ConvUnit(f=out_ch, k=3, s=1, wd=wd, act=act)

    @staticmethod
    def _upsample_add(coarse, fine):
        """Nearest-neighbor upsample `coarse` to `fine`'s HxW and add."""
        target_h = tf.shape(fine)[1]
        target_w = tf.shape(fine)[2]
        upsampled = tf.image.resize(coarse, [target_h, target_w],
                                    method='nearest')
        return fine + upsampled

    def call(self, x):
        # Lateral 1x1 projections of the three input levels.
        p1 = self.output1(x[0])
        p2 = self.output2(x[1])
        p3 = self.output3(x[2])

        # Top-down pathway: merge coarser levels into finer ones.
        p2 = self.merge2(self._upsample_add(p3, p2))
        p1 = self.merge1(self._upsample_add(p2, p1))

        return p1, p2, p3

    def get_config(self):
        base_config = super(FPN, self).get_config()
        base_config.update({'out_ch': self.out_ch, 'wd': self.wd})
        return base_config
|
|
|
|
|
class SSH(tf.keras.layers.Layer):
    """Single Stage Headless context module (3x3 / 5x5 / 7x7 branches)."""
    def __init__(self, out_ch, wd, **kwargs):
        super(SSH, self).__init__(**kwargs)
        assert out_ch % 4 == 0
        self.out_ch = out_ch
        self.wd = wd
        act = 'lrelu' if out_ch <= 64 else 'relu'

        half = out_ch // 2
        quarter = out_ch // 4
        # 3x3 branch produces half the channels; the two context branches
        # a quarter each, so the concat yields out_ch total.
        self.conv_3x3 = ConvUnit(f=half, k=3, s=1, wd=wd, act=None)

        self.conv_5x5_1 = ConvUnit(f=quarter, k=3, s=1, wd=wd, act=act)
        self.conv_5x5_2 = ConvUnit(f=quarter, k=3, s=1, wd=wd, act=None)

        self.conv_7x7_2 = ConvUnit(f=quarter, k=3, s=1, wd=wd, act=act)
        self.conv_7x7_3 = ConvUnit(f=quarter, k=3, s=1, wd=wd, act=None)

        self.relu = ReLU()

    def call(self, x):
        branch_3x3 = self.conv_3x3(x)

        # The first conv of the 5x5 branch is shared with the 7x7 branch.
        shared = self.conv_5x5_1(x)
        branch_5x5 = self.conv_5x5_2(shared)
        branch_7x7 = self.conv_7x7_3(self.conv_7x7_2(shared))

        merged = tf.concat([branch_3x3, branch_5x5, branch_7x7], axis=3)
        return self.relu(merged)

    def get_config(self):
        base_config = super(SSH, self).get_config()
        base_config.update({'out_ch': self.out_ch, 'wd': self.wd})
        return base_config
|
|
|
|
|
class BboxHead(tf.keras.layers.Layer):
    """Bbox Head Layer: predicts 4 box offsets per anchor."""
    def __init__(self, num_anchor, wd, **kwargs):
        super(BboxHead, self).__init__(**kwargs)
        self.num_anchor = num_anchor
        # `wd` is stored only for get_config; this conv carries no
        # regularizer (matches the sibling heads).
        self.wd = wd
        self.conv = Conv2D(filters=num_anchor * 4, kernel_size=1, strides=1)

    def call(self, x):
        shape = tf.shape(x)
        out = self.conv(x)
        # Flatten spatial grid and anchors -> [batch, H*W*A, 4].
        return tf.reshape(out, [-1, shape[1] * shape[2] * self.num_anchor, 4])

    def get_config(self):
        base_config = super(BboxHead, self).get_config()
        base_config.update({'num_anchor': self.num_anchor, 'wd': self.wd})
        return base_config
|
|
|
|
|
class LandmarkHead(tf.keras.layers.Layer):
    """Landmark Head Layer: predicts 5 (x, y) landmark pairs per anchor."""
    def __init__(self, num_anchor, wd, name='LandmarkHead', **kwargs):
        super(LandmarkHead, self).__init__(name=name, **kwargs)
        self.num_anchor = num_anchor
        # Stored for serialization only; the conv itself is unregularized.
        self.wd = wd
        self.conv = Conv2D(filters=num_anchor * 10, kernel_size=1, strides=1)

    def call(self, x):
        shape = tf.shape(x)
        out = self.conv(x)
        # Flatten spatial grid and anchors -> [batch, H*W*A, 10].
        return tf.reshape(out, [-1, shape[1] * shape[2] * self.num_anchor, 10])

    def get_config(self):
        base_config = super(LandmarkHead, self).get_config()
        base_config.update({'num_anchor': self.num_anchor, 'wd': self.wd})
        return base_config
|
|
|
|
|
class ClassHead(tf.keras.layers.Layer):
    """Class Head Layer: predicts background/face logits per anchor."""
    def __init__(self, num_anchor, wd, name='ClassHead', **kwargs):
        super(ClassHead, self).__init__(name=name, **kwargs)
        self.num_anchor = num_anchor
        # Stored for serialization only; the conv itself is unregularized.
        self.wd = wd
        self.conv = Conv2D(filters=num_anchor * 2, kernel_size=1, strides=1)

    def call(self, x):
        shape = tf.shape(x)
        out = self.conv(x)
        # Flatten spatial grid and anchors -> [batch, H*W*A, 2].
        return tf.reshape(out, [-1, shape[1] * shape[2] * self.num_anchor, 2])

    def get_config(self):
        base_config = super(ClassHead, self).get_config()
        base_config.update({'num_anchor': self.num_anchor, 'wd': self.wd})
        return base_config
|
|
|
|
|
def RetinaFaceModel(cfg, training=False, iou_th=0.4, score_th=0.02,
                    name='RetinaFaceModel'):
    """Build the RetinaFace detector.

    Args:
        cfg: config dict — reads 'input_size', 'weights_decay',
            'out_channel', 'min_sizes', 'steps', 'clip', 'variances',
            'backbone_type'.
        training: if True the model outputs raw (bbox, landm, class)
            tensors; if False it adds anchor decoding + NMS and outputs
            detections for a single image.
        iou_th: NMS IoU threshold (inference only).
        score_th: NMS score threshold (inference only).
        name: model name.

    Returns:
        A tuple of two Keras Models sharing the same graph:
        (full model with the training/inference head, and a raw-output
        model named ``name + '_bb_only'``).
    """
    # Inference uses a dynamic spatial size (None); training is fixed-size.
    input_size = cfg['input_size'] if training else None
    wd = cfg['weights_decay']
    out_ch = cfg['out_channel']
    num_anchor = len(cfg['min_sizes'][0])
    backbone_type = cfg['backbone_type']

    # define model
    x = inputs = Input([input_size, input_size, 3], name='input_image')

    x = Backbone(backbone_type=backbone_type)(x)

    fpn = FPN(out_ch=out_ch, wd=wd)(x)

    # One SSH context module per pyramid level.
    # (idiom fix: dropped unused enumerate indices here and in the bbox head)
    features = [SSH(out_ch=out_ch, wd=wd)(f) for f in fpn]

    bbox_regressions = tf.concat(
        [BboxHead(num_anchor, wd=wd)(f) for f in features], axis=1)
    landm_regressions = tf.concat(
        [LandmarkHead(num_anchor, wd=wd, name=f'LandmarkHead_{i}')(f)
         for i, f in enumerate(features)], axis=1)
    classifications = tf.concat(
        [ClassHead(num_anchor, wd=wd, name=f'ClassHead_{i}')(f)
         for i, f in enumerate(features)], axis=1)

    classifications = tf.keras.layers.Softmax(axis=-1)(classifications)

    if training:
        out = (bbox_regressions, landm_regressions, classifications)
    else:
        # Inference assumes batch size 1 (note the [0] indexing).
        # Per-prior layout: bbox(4) + landm(10) + valid_flag(1) + score(1).
        preds = tf.concat(
            [bbox_regressions[0],
             landm_regressions[0],
             tf.ones_like(classifications[0, :, 0][..., tf.newaxis]),
             classifications[0, :, 1][..., tf.newaxis]], 1)
        priors = prior_box_tf((tf.shape(inputs)[1], tf.shape(inputs)[2]),
                              cfg['min_sizes'], cfg['steps'], cfg['clip'])
        decode_preds = decode_tf(preds, priors, cfg['variances'])

        selected_indices = tf.image.non_max_suppression(
            boxes=decode_preds[:, :4],
            scores=decode_preds[:, -1],
            max_output_size=tf.shape(decode_preds)[0],
            iou_threshold=iou_th,
            score_threshold=score_th)

        out = tf.gather(decode_preds, selected_indices)

    return (Model(inputs, out, name=name),
            Model(inputs,
                  [bbox_regressions, landm_regressions, classifications],
                  name=name + '_bb_only'))
|
|