# inversion_testing/metrics/face_eval.py
# Author: ethanNeuralImage, commit c85e4eb ("Adding in metrics"), 3.58 kB.
# (Hugging Face page labels: raw | history blame | No virus)
from .face_parsing import BiSeNet
import numpy as np
from .metrics import LPIPS, MS_SSIM, IdScore, ClipHair
import torch.nn as nn
import torch
from torchvision import transforms
class FaceSegmentation(nn.Module):
    """Turn a BiSeNet face-parsing label map into binary face and hair masks.

    Each pixel gets one of ``n_classes`` labels. Labels in ``FACE_LABELS``
    form the face mask, and ``HAIR_LABEL`` forms the hair mask.
    """

    # Parsing-map labels counted as "face":
    # 1 skin, 2 l_brow, 3 r_brow, 4 l_eye, 5 r_eye, 6 eye_glass, 7 l_ear,
    # 8 r_ear, 10 nose, 11 mouth, 12 u_lip, 13 l_lip, 14 neck.
    FACE_LABELS = (1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14)
    # Parsing-map label counted as "hair".
    HAIR_LABEL = 17

    def __init__(self, n_classes=19, device='cuda', save_pth='./pretrained_models/79999_iter.pth'):
        """Build the parser and load its weights.

        Args:
            n_classes: Number of segmentation classes of the BiSeNet model.
            device: Device the network and all produced tensors live on.
            save_pth: Path to the BiSeNet state-dict checkpoint.
        """
        super(FaceSegmentation, self).__init__()
        self.net = BiSeNet(n_classes=n_classes).to(device)
        # map_location lets a checkpoint saved on a GPU load onto the target
        # device, including CPU-only hosts.
        self.net.load_state_dict(torch.load(save_pth, map_location=device))
        self.net.eval()
        # ToTensor converts the input image to a (C, H, W) float tensor;
        # Normalize then applies ImageNet mean/std.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        self.device = device

    def get_facemask(self, parsing_anno):
        """Return an int32 binary mask that is 1 wherever the label is a face class.

        Args:
            parsing_anno: Integer label map (as produced by ``forward``).
        """
        # Build the label set on the input's own device so this works
        # wherever the label map lives.
        face_attr = torch.tensor(self.FACE_LABELS, device=parsing_anno.device)
        return torch.isin(parsing_anno, face_attr).int()

    def get_hairmask(self, parsing_anno):
        """Return an int32 binary mask that is 1 wherever the label is hair.

        Args:
            parsing_anno: Integer label map (as produced by ``forward``).
        """
        return (parsing_anno == self.HAIR_LABEL).int()

    def forward(self, img):
        """Segment one image into face and hair regions.

        Args:
            img: An image accepted by ``transforms.ToTensor`` (PIL image or
                ndarray).

        Returns:
            ``(img_tensor, face_mask, hair_mask)``. ``img_tensor`` is the
            normalized image on ``self.device``. The masks are int32 binary
            tensors on ``self.device``.
        """
        img = self.transform(img).to(self.device)
        # Inference only: the masks come from argmax, so no autograd graph
        # is needed.
        with torch.no_grad():
            # Add a batch dim, then take the first network output and drop
            # the batch dim. Argmax over the class axis gives the label map.
            # Assumes net returns a sequence whose first item is
            # (1, n_classes, H, W) logits - TODO confirm against BiSeNet.
            parsing_anno = self.net(img.unsqueeze(0))[0].squeeze(0).argmax(0)
        face_mask = self.get_facemask(parsing_anno).to(self.device)
        hair_mask = self.get_hairmask(parsing_anno).to(self.device)
        return img, face_mask, hair_mask


class FaceMetric(nn.Module):
    """Compare two images with a chosen metric, restricted to face and/or hair regions.

    Args:
        metric_type: One of ``'ms-ssim'``, ``'lpips'``, ``'id'`` or
            ``'cliphair'``.
        eval_face: Whether to score the face region. Always ``False`` for
            ``'cliphair'``.
        eval_hair: Whether to score the hair region. Always ``False`` for
            ``'id'``.
        device: Device for the metric and segmentation networks.
        seg_save_pth: Path to the face-parsing checkpoint passed to
            ``FaceSegmentation``.

    Raises:
        NotImplementedError: If ``metric_type`` is not one of the above.
    """

    def __init__(self, metric_type, eval_face=True, eval_hair=True, device='cuda',
                 seg_save_pth='./pretrained_models/79999_iter.pth'):
        super(FaceMetric, self).__init__()
        self.eval_face = eval_face
        self.eval_hair = eval_hair
        if metric_type == 'ms-ssim':
            self.metric = MS_SSIM()
        elif metric_type == 'lpips':
            self.metric = LPIPS(device=device)
        elif metric_type == 'id':
            self.metric = IdScore(device=device)
            # 'id' is only ever evaluated on the face region.
            self.eval_hair = False
        elif metric_type == 'cliphair':
            self.metric = ClipHair(device=device)
            # 'cliphair' is only ever evaluated on the hair region.
            self.eval_face = False
        else:
            raise NotImplementedError(f"Unsupported metric_type: {metric_type!r}")
        self.parser = FaceSegmentation(device=device, save_pth=seg_save_pth)
        self.device = device

    def forward(self, x, y):
        """Score image ``y`` against image ``x`` on the enabled regions.

        Args:
            x: First image, in any format the segmentation parser accepts.
            y: Second image, in the same format.

        Returns:
            ``(face_score, hair_score)`` as Python floats. A region that is
            not evaluated returns ``None``.
        """
        face_score, hair_score = None, None
        # Scores are reduced to Python floats with .item(), so gradients are
        # never used; skip building the autograd graph.
        with torch.no_grad():
            x_tensor, x_face_seg, x_hair_seg = self.parser(x)
            y_tensor, y_face_seg, y_hair_seg = self.parser(y)
            # NOTE(review): the metric receives the parser's ImageNet-normalized
            # tensors with masked-out pixels set to 0 (the mean colour, not
            # black). Confirm this matches each metric's expected input range.
            if self.eval_hair:
                # Each image is masked by its own hair region. A shared union
                # of the two hair masks would be the alternative.
                hair_score = self.metric(x_tensor * x_hair_seg,
                                         y_tensor * y_hair_seg).item()
            if self.eval_face:
                # Compare only pixels labelled face in BOTH images.
                face_mask = torch.logical_and(x_face_seg.bool(), y_face_seg.bool())
                face_score = self.metric(x_tensor * face_mask,
                                         y_tensor * face_mask).item()
        return face_score, hair_score