|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
import os |
|
import pickle |
|
|
|
import paddle |
|
import paddle.nn as nn |
|
import paddle.nn.functional as F |
|
|
|
|
|
class CenterLoss(nn.Layer):
    """Center loss between feature vectors and per-class centers.

    Reference: Wen et al. A Discriminative Feature Learning Approach for Deep
    Face Recognition. ECCV 2016.

    The layer keeps one ``feat_dim``-dimensional center per class. When
    called, it labels every feature vector with the argmax of the matching
    prediction and returns the summed squared Euclidean distance between each
    feature vector and the center of its label, divided by the number of
    feature vectors.

    Args:
        num_classes: Number of class centers to keep.
        feat_dim: Length of each center (and of each feature vector).
        center_file_path: Optional path to a pickle file containing a mapping
            from class index to center vector. Rows listed in the mapping
            overwrite the randomly initialised centers.
    """

    def __init__(self, num_classes=6625, feat_dim=96, center_file_path=None):
        super().__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        # Centers start as standard-normal noise, stored as float64 to match
        # the float64 arithmetic used when the loss is evaluated.
        self.centers = paddle.randn(
            shape=[self.num_classes, self.feat_dim]).astype("float64")

        if center_file_path is not None:
            assert os.path.exists(
                center_file_path
            ), f"center path({center_file_path}) must exist when it is not None."
            # NOTE(security): pickle.load can execute arbitrary code embedded
            # in the file; only point center_file_path at trusted files.
            with open(center_file_path, 'rb') as f:
                char_dict = pickle.load(f)
            # Overwrite only the rows present in the file; every other class
            # keeps its random initialisation.
            for key, center in char_dict.items():
                self.centers[key] = paddle.to_tensor(center)

    def __call__(self, predicts, batch):
        """Compute the center loss for one batch.

        Args:
            predicts: A list or tuple ``(features, predicts)``. ``predicts``
                must be 3-D, because labels are taken with ``argmax`` over
                axis 2 and the result is flattened over its first two axes.
                ``features`` is flattened to ``[-1, features.shape[-1]]``.
                Presumably both are laid out as (batch, sequence, dim) with
                ``features.shape[-1] == feat_dim`` -- confirm with the caller.
            batch: Unused; accepted only to keep the call signature.

        Returns:
            dict: ``{'loss_center': loss}`` where ``loss`` is the summed
            clipped squared distance divided by the number of feature
            vectors.
        """
        assert isinstance(predicts, (list, tuple))
        features, predicts = predicts

        # Flatten to one feature vector / one pseudo-label per row.
        feats_reshape = paddle.reshape(
            features, [-1, features.shape[-1]]).astype("float64")
        label = paddle.argmax(predicts, axis=2)
        label = paddle.reshape(label, [label.shape[0] * label.shape[1]])

        batch_size = feats_reshape.shape[0]

        # Squared distance via ||x - c||^2 = ||x||^2 + ||c||^2 - 2 * x.c,
        # evaluated for every (feature, center) pair at once.
        square_feat = paddle.sum(paddle.square(feats_reshape),
                                 axis=1,
                                 keepdim=True)
        square_feat = paddle.expand(square_feat, [batch_size, self.num_classes])

        square_center = paddle.sum(paddle.square(self.centers),
                                   axis=1,
                                   keepdim=True)
        square_center = paddle.expand(
            square_center, [self.num_classes, batch_size]).astype("float64")
        square_center = paddle.transpose(square_center, [1, 0])

        distmat = paddle.add(square_feat, square_center)
        feat_dot_center = paddle.matmul(feats_reshape,
                                        paddle.transpose(self.centers, [1, 0]))
        distmat = distmat - 2.0 * feat_dot_center

        # One-hot mask that selects, for each row, the column of its label.
        classes = paddle.arange(self.num_classes).astype("int64")
        label = paddle.expand(
            paddle.unsqueeze(label, 1), (batch_size, self.num_classes))
        mask = paddle.equal(
            paddle.expand(classes, [batch_size, self.num_classes]),
            label).astype("float64")

        # Clip BEFORE masking. The clip keeps distances inside
        # [1e-12, 1e+12] (the expanded formula can dip slightly below zero
        # through floating-point cancellation). Clipping after masking would
        # lift every masked-out zero to 1e-12 and add
        # (num_classes - 1) * 1e-12 per feature vector to the loss.
        dist = paddle.multiply(
            paddle.clip(distmat, min=1e-12, max=1e+12), mask)

        loss = paddle.sum(dist) / batch_size
        return {'loss_center': loss}
|
|