import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms

from .face_parsing import BiSeNet
from .metrics import LPIPS, MS_SSIM, IdScore, ClipHair


class FaceSegmentation(nn.Module):
    """Face-parsing wrapper that turns an image into face and hair masks.

    Wraps a pretrained ``BiSeNet`` and the preprocessing it is fed with
    (``ToTensor`` followed by a fixed per-channel ``Normalize``).
    """

    # Parsing labels that count as "face":
    # 1 skin, 2 l_brow, 3 r_brow, 4 l_eye, 5 r_eye, 6 eye_glass, 7 l_ear,
    # 8 r_ear, 10 nose, 11 mouth, 12 u_lip, 13 l_lip, 14 neck.
    FACE_LABELS = (1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14)
    # Parsing label that counts as "hair".
    HAIR_LABEL = 17

    def __init__(self, n_classes=19, device='cuda',
                 save_pth='./pretrained_models/79999_iter.pth'):
        """Build the parsing network and load its weights.

        Args:
            n_classes: Number of classes passed to ``BiSeNet``.
            device: Device the network and all produced tensors live on.
            save_pth: Path of the checkpoint given to ``torch.load``.
        """
        super().__init__()
        self.net = BiSeNet(n_classes=n_classes).to(device)
        # map_location lets a checkpoint saved on another device (e.g. CUDA)
        # be loaded when running on `device` (e.g. a CPU-only host).
        self.net.load_state_dict(torch.load(save_pth, map_location=device))
        self.net.eval()
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        self.device = device

    def get_facemask(self, parsing_anno):
        """Return an int mask that is 1 where the label is a face label.

        Args:
            parsing_anno: Tensor of per-pixel class labels.

        Returns:
            Int tensor of the same shape as ``parsing_anno`` (1 = face).
        """
        face_attr = torch.tensor(self.FACE_LABELS, device=self.device)
        return torch.isin(parsing_anno, face_attr).int()

    def get_hairmask(self, parsing_anno):
        """Return an int mask that is 1 where the label is the hair label.

        Args:
            parsing_anno: Tensor of per-pixel class labels.

        Returns:
            Int tensor of the same shape as ``parsing_anno`` (1 = hair).
        """
        return (parsing_anno == self.HAIR_LABEL).int()

    def forward(self, img):
        """Preprocess an image and compute its face and hair masks.

        Args:
            img: A single image accepted by ``self.transform`` (whatever
                ``transforms.ToTensor`` supports); it is NOT batched here,
                a batch dimension is added internally.

        Returns:
            Tuple ``(img, face_mask, hair_mask)``: the transformed image
            tensor on ``self.device`` and the two int masks derived from
            the network's per-pixel argmax.
        """
        img = self.transform(img).to(self.device)
        # Only the argmax of the prediction is used, which is not
        # differentiable, so skip building an autograd graph.
        with torch.no_grad():
            # Takes the first output of the network, drops the batch
            # dimension and picks the highest-scoring class along dim 0
            # (presumably the class channel -- TODO confirm with BiSeNet).
            parsing_anno = self.net(img.unsqueeze(0))[0].squeeze(0).argmax(0)
        face_mask = self.get_facemask(parsing_anno)
        hair_mask = self.get_hairmask(parsing_anno)
        return img, face_mask, hair_mask


class FaceMetric(nn.Module):
    """Evaluate a similarity metric separately on face and hair regions.

    The regions are obtained with ``FaceSegmentation``. Supported
    ``metric_type`` values are ``'ms-ssim'``, ``'lpips'``, ``'id'`` and
    ``'cliphair'``.
    """

    def __init__(self, metric_type, eval_face=True, eval_hair=True,
                 device='cuda', seg_save_pth='./pretrained_models/79999_iter.pth'):
        """Select the metric and build the face parser.

        Args:
            metric_type: One of ``'ms-ssim'``, ``'lpips'``, ``'id'``,
                ``'cliphair'``.
            eval_face: Whether to score the face region. Ignored (forced
                off) for ``'cliphair'``.
            eval_hair: Whether to score the hair region. Ignored (forced
                off) for ``'id'``.
            device: Device passed to the metric (where it takes one) and to
                the parser.
            seg_save_pth: Checkpoint path for ``FaceSegmentation``.

        Raises:
            NotImplementedError: If ``metric_type`` is not supported.
        """
        super().__init__()
        if metric_type == 'ms-ssim':
            self.metric = MS_SSIM()
            self.eval_hair = eval_hair
            self.eval_face = eval_face
        elif metric_type == 'lpips':
            self.metric = LPIPS(device=device)
            self.eval_hair = eval_hair
            self.eval_face = eval_face
        elif metric_type == 'id':
            # The identity score is only evaluated on the face region.
            self.metric = IdScore(device=device)
            self.eval_hair = False
            self.eval_face = eval_face
        elif metric_type == 'cliphair':
            # The hair score is only evaluated on the hair region.
            self.metric = ClipHair(device=device)
            self.eval_face = False
            self.eval_hair = eval_hair
        else:
            raise NotImplementedError(
                f"Unsupported metric_type: {metric_type!r}"
            )
        self.parser = FaceSegmentation(device=device, save_pth=seg_save_pth)
        self.device = device

    def forward(self, x, y):
        """Score two images on their face and/or hair regions.

        Args:
            x: First image, in a form accepted by ``FaceSegmentation``.
            y: Second image, same form. The face branch adds the two face
                masks, so both images presumably need the same spatial
                size -- TODO confirm against callers.

        Returns:
            Tuple ``(face_score, hair_score)`` of Python numbers; an entry
            is ``None`` when that region is not evaluated.
        """
        face_score, hair_score = None, None
        x_tensor, x_face_seg, x_hair_seg = self.parser(x)
        y_tensor, y_face_seg, y_hair_seg = self.parser(y)

        if self.eval_hair:
            # Each image is masked with its OWN hair mask (not the union of
            # the two masks), so differing hair regions affect the score.
            x_hair = x_tensor * x_hair_seg
            y_hair = y_tensor * y_hair_seg
            hair_score = self.metric(x_hair, y_hair).item()

        if self.eval_face:
            # Intersection of the two face masks: a pixel is kept only when
            # both masks are 1 (sum > 1).
            face_mask = (x_face_seg + y_face_seg) > 1
            x_face = x_tensor * face_mask
            y_face = y_tensor * face_mask
            face_score = self.metric(x_face, y_face).item()

        return face_score, hair_score