Spaces:
Runtime error
Runtime error
from .face_parsing import BiSeNet | |
import numpy as np | |
from .metrics import LPIPS, MS_SSIM, IdScore, ClipHair | |
import torch.nn as nn | |
import torch | |
from torchvision import transforms | |
class FaceSegmentation(nn.Module):
    """Face parser that turns an image into face and hair masks.

    Wraps a BiSeNet segmentation network whose weights are loaded from
    ``save_pth`` and reduces its per-pixel class prediction to two binary
    (0/1, int32) masks: one for face parts and one for hair.
    """

    # Parsing labels counted as "face". Label names:
    # 1 skin, 2 l_brow, 3 r_brow, 4 l_eye, 5 r_eye, 6 eye_glass, 7 l_ear,
    # 8 r_ear, 10 nose, 11 mouth, 12 u_lip, 13 l_lip, 14 neck.
    FACE_LABELS = (1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14)
    # Parsing label treated as "hair".
    HAIR_LABEL = 17

    def __init__(self, n_classes=19, device='cuda', save_pth='./pretrained_models/79999_iter.pth'):
        """Build the parser and load its pretrained weights.

        Args:
            n_classes: number of output classes of the BiSeNet network.
            device: device the network runs on and images are moved to.
            save_pth: path to the BiSeNet state-dict checkpoint.
        """
        super(FaceSegmentation, self).__init__()
        # Load the checkpoint onto the CPU first so it loads regardless of the
        # device it was saved from, then move the network to `device` once.
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints (consider weights_only=True on torch >= 1.13).
        self.net = BiSeNet(n_classes=n_classes)
        self.net.load_state_dict(torch.load(save_pth, map_location=torch.device('cpu')))
        self.net.to(device).eval()
        # ToTensor + ImageNet mean/std normalization applied to every input.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        self.device = device

    def get_facemask(self, parsing_anno):
        """Return a binary mask of the face.

        Args:
            parsing_anno: integer tensor of per-pixel parsing labels.

        Returns:
            int32 tensor shaped like ``parsing_anno``: 1 where the label is in
            ``FACE_LABELS``, 0 elsewhere.
        """
        # Build the label set on the input's own device so torch.isin never
        # mixes devices.
        face_labels = torch.tensor(self.FACE_LABELS, device=parsing_anno.device)
        return torch.isin(parsing_anno, face_labels).int()

    def get_hairmask(self, parsing_anno):
        """Return a binary mask of the hair.

        Args:
            parsing_anno: integer tensor of per-pixel parsing labels.

        Returns:
            int32 tensor shaped like ``parsing_anno``: 1 where the label equals
            ``HAIR_LABEL``, 0 elsewhere.
        """
        return (parsing_anno == self.HAIR_LABEL).int()

    def forward(self, img):
        """Parse ``img`` and return it together with its face and hair masks.

        Args:
            img: an image accepted by ``transforms.ToTensor``.

        Returns:
            ``(img_tensor, face_mask, hair_mask)``. ``img_tensor`` is the
            normalized tensor that was fed to the network. Both masks are
            int32 0/1 tensors with the spatial shape of the network's label
            map.
        """
        img = self.transform(img).to(self.device)
        # Only the argmax of the logits is used, so no autograd graph is
        # needed; no_grad avoids keeping the network's activations alive.
        with torch.no_grad():
            # NOTE(review): assumes the network returns a sequence whose first
            # element is the (1, n_classes, H, W) logit map — confirm.
            logits = self.net(img.unsqueeze(0))[0]
        # Drop the batch dim and take the most likely class per pixel.
        parsing_anno = logits.squeeze(0).argmax(0)
        return img, self.get_facemask(parsing_anno), self.get_hairmask(parsing_anno)
class FaceMetric(nn.Module):
    """Image-similarity metric restricted to face and/or hair regions.

    Both images are run through a ``FaceSegmentation`` parser, masked to the
    region of interest, and then compared with the selected metric.
    """

    def __init__(self, metric_type, eval_face=True, eval_hair=True, device='cuda', seg_save_pth='./pretrained_models/79999_iter.pth'):
        """Create the metric and its face parser.

        Args:
            metric_type: one of ``'ms-ssim'``, ``'lpips'``, ``'id'``, or
                ``'cliphair'``.
            eval_face: whether to score the face region. This is always
                disabled for ``'cliphair'``.
            eval_hair: whether to score the hair region. This is always
                disabled for ``'id'``.
            device: device passed to the metric (where it accepts one) and to
                the parser.
            seg_save_pth: checkpoint path for the face-parsing network.

        Raises:
            NotImplementedError: if ``metric_type`` is not supported.
        """
        super(FaceMetric, self).__init__()
        # metric_type -> (lazy metric factory, supports face, supports hair).
        # Lambdas defer construction so an unknown type fails before any
        # metric is built.
        factories = {
            'ms-ssim': (lambda: MS_SSIM(), True, True),
            'lpips': (lambda: LPIPS(device=device), True, True),
            'id': (lambda: IdScore(device=device), True, False),
            'cliphair': (lambda: ClipHair(device=device), False, True),
        }
        if metric_type not in factories:
            raise NotImplementedError(f"Unsupported metric_type: {metric_type!r}")
        make_metric, supports_face, supports_hair = factories[metric_type]
        self.metric = make_metric()
        self.eval_face = bool(eval_face) and supports_face
        self.eval_hair = bool(eval_hair) and supports_hair
        self.parser = FaceSegmentation(device=device, save_pth=seg_save_pth)
        self.device = device

    def forward(self, x, y):
        """Compare ``x`` and ``y`` on the enabled regions.

        Args:
            x, y: images accepted by the face parser.

        Returns:
            ``(face_score, hair_score)`` as Python floats. An entry is
            ``None`` when that region is not evaluated.
        """
        face_score, hair_score = None, None
        if not (self.eval_face or self.eval_hair):
            return face_score, hair_score
        # Scores are reduced to Python floats with .item(), so gradients are
        # never needed.
        with torch.no_grad():
            x_tensor, x_face_seg, x_hair_seg = self.parser(x)
            y_tensor, y_face_seg, y_hair_seg = self.parser(y)
            if self.eval_hair:
                # Each image is masked with its own hair mask (not a shared
                # union/intersection of the two).
                hair_score = self.metric(x_tensor * x_hair_seg, y_tensor * y_hair_seg).item()
            if self.eval_face:
                # Compare only pixels labelled as face in BOTH images.
                face_mask = torch.logical_and(x_face_seg.bool(), y_face_seg.bool())
                face_score = self.metric(x_tensor * face_mask, y_tensor * face_mask).item()
        return face_score, hair_score