|
import torch |
|
from torch import nn |
|
|
|
from models.facial_recognition.model_irse import Backbone |
|
|
|
|
|
class IDLoss(nn.Module):
    """Identity loss: ``1 - dot(f(y_hat), f(y))`` with a face-recognition backbone.

    Both images are embedded with a pretrained IR-SE ``Backbone`` (50 layers,
    112x112 input, per the constructor arguments below). Dissimilar embeddings
    are penalized. The backbone is switched to eval mode at construction time.
    """

    def __init__(self, opts):
        """Build the backbone and load its pretrained weights.

        Args:
            opts: options object. ``opts.ir_se50_weights`` is the path of the
                backbone state dict passed to ``torch.load``. The object is
                stored as ``self.opts``.
        """
        super(IDLoss, self).__init__()
        print('Loading ResNet ArcFace')
        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
        # load_state_dict copies the tensors into the (CPU-built) module
        # parameters anyway, so the resulting weights are unchanged.
        self.facenet.load_state_dict(torch.load(opts.ir_se50_weights, map_location='cpu'))
        self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()
        self.opts = opts

    def extract_feats(self, x):
        """Embed a batch of images with the face-recognition backbone.

        Args:
            x: image batch laid out as (N, C, H, W). The crop below slices
                dims 2 and 3 as rows and columns.

        Returns:
            The backbone output for the 112x112 crop of each image.
        """
        # Resize to 256x256 so the fixed crop window always lands in the same
        # place. Both spatial dims are checked, so non-square inputs with
        # H == 256 (e.g. 256x300) are resized too.
        if x.shape[-2:] != (256, 256):
            x = self.pool(x)
        # Fixed 188x188 crop (rows 35:223, cols 32:220). This is presumably the
        # face region of aligned 256x256 images — TODO confirm against the
        # alignment used to prepare the data.
        x = x[:, :, 35:223, 32:220]
        x = self.face_pool(x)  # the backbone was built with input_size=112
        return self.facenet(x)

    def forward(self, y_hat, y):
        """Compute the batch-mean identity loss.

        Args:
            y_hat: generated images. Gradients flow through these.
            y: target images. These are treated as constants (no gradient).

        Returns:
            ``(loss, sim_improvement)``:
            - ``loss`` is the scalar tensor ``mean_i(1 - <f(y_hat_i), f(y_i)>)``.
            - ``sim_improvement`` is always ``0.0``; it is kept only so callers
              that unpack two values keep working.

        Raises:
            ValueError: if the batch is empty or the batch sizes of ``y_hat``
                and ``y`` differ.
        """
        n_samples = y.shape[0]
        if n_samples == 0:
            raise ValueError('IDLoss requires a non-empty batch')
        if y_hat.shape[0] != n_samples:
            raise ValueError(
                f'batch size mismatch: y_hat has {y_hat.shape[0]} samples, '
                f'y has {n_samples}'
            )
        # Target embeddings are constants. no_grad gives the same values as
        # .detach() without building a graph through the backbone first.
        with torch.no_grad():
            y_feats = self.extract_feats(y)
        y_hat_feats = self.extract_feats(y_hat)
        # Row-wise dot products over (N, D) features. If the backbone returns
        # L2-normalized embeddings, this is cosine similarity — NOTE(review):
        # confirm in model_irse.
        sims = (y_hat_feats * y_feats).sum(dim=1)
        loss = (1 - sims).mean()
        return loss, 0.0
|
|