""" Code borrowed from https://gist.github.com/alper111/8233cdb0414b4cb5853f2f730ab95a49#file-vgg_perceptual_loss-py-L5 """

import torch
import torchvision

from models.vggface import VGGFaceFeats


def cos_loss(fi, ft):
    """Return ``1 - mean cosine similarity`` between two feature tensors.

    Similarity is taken with ``torch.nn.functional.cosine_similarity`` along
    its default dimension (``dim=1``) and then averaged over every remaining
    element.
    """
    return 1 - torch.nn.functional.cosine_similarity(fi, ft).mean()


class VGGPerceptualLoss(torch.nn.Module):
    """Perceptual loss over the first four stages of an ImageNet VGG16.

    Inputs are expected in ``[-1, 1]``. They are rescaled to ``[0, 1]``,
    normalized with the ImageNet channel mean/std and passed through frozen
    VGG16 feature stages; the per-stage distances are summed.
    """

    def __init__(self, resize=False):
        super().__init__()
        # Instantiate VGG16 once and slice it. Slicing an nn.Sequential keeps
        # the original child indices, so state_dict keys are the same as when
        # each slice came from a separate vgg16() call.
        features = torchvision.models.vgg16(pretrained=True).features
        blocks = [
            features[:4].eval(),
            features[4:9].eval(),
            features[9:16].eval(),
            features[16:23].eval(),
        ]
        for bl in blocks:
            # Iterating an nn.Sequential yields sub-modules, not tensors, so
            # freezing must go through parameters() to actually take effect.
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        # Fixed normalization constants: buffers follow .to(device) and are
        # saved under the same "mean"/"std" keys, but are never trainable.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        )
        self.resize = resize

    def forward(self, input, target, max_layer=4, cos_dist: bool = False):
        """Sum per-stage VGG16 feature distances between ``input`` and ``target``.

        Args:
            input: image batch in ``[-1, 1]``. If dim 1 is not 3 it is tiled
                with ``repeat(1, 3, 1, 1)`` (only meaningful for 1-channel
                input — TODO confirm callers never pass other channel counts).
            target: reference batch with the same layout; its features are
                detached so no gradient flows into it.
            max_layer: number of VGG stages (up to 4) to accumulate over.
            cos_dist: use ``cos_loss`` instead of L1 as the per-stage distance.

        Returns:
            The summed distance (a tensor, or ``0.0`` if no stage is selected).
        """
        target = (target + 1) * 0.5
        input = (input + 1) * 0.5
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        x = input
        y = target
        loss = 0.0
        loss_func = cos_loss if cos_dist else torch.nn.functional.l1_loss
        for block in self.blocks[:max_layer]:
            x = block(x)
            y = block(y)
            loss += loss_func(x, y.detach())
        return loss


class VGGFacePerceptualLoss(torch.nn.Module):
    """Perceptual loss over features produced by ``VGGFaceFeats``.

    Inputs are expected in ``[-1, 1]``. They are rescaled to ``[0, 1]`` and
    mean-subtracted before feature extraction.
    """

    def __init__(self, weight_path: str = "checkpoint/vgg_face_dag.pt", resize: bool = False):
        super().__init__()
        self.vgg = VGGFaceFeats()
        # map_location="cpu" lets a GPU-saved checkpoint load on CPU-only
        # machines; load_state_dict copies into the module either way.
        # NOTE: torch.load unpickles the file — only use trusted checkpoints.
        self.vgg.load_state_dict(torch.load(weight_path, map_location="cpu"))
        # The loss network is a fixed feature extractor: freeze its weights so
        # backprop through it does not compute gradients for them.
        for p in self.vgg.parameters():
            p.requires_grad = False
        # The stored mean is divided by 255 to match inputs rescaled to [0, 1].
        # NOTE(review): whether VGGFaceFeats expects [0, 255] inputs is not
        # visible here — confirm this scaling is intended.
        mean = torch.tensor(self.vgg.meta["mean"]).view(1, 3, 1, 1) / 255.0
        self.register_buffer("mean", mean)
        self.transform = torch.nn.functional.interpolate
        self.resize = resize

    def forward(self, input, target, max_layer: int = 4, cos_dist: bool = False):
        """Sum feature distances over the first ``max_layer`` VGGFace outputs.

        Args:
            input: image batch in ``[-1, 1]``. If dim 1 is not 3 it is tiled
                with ``repeat(1, 3, 1, 1)``.
            target: reference batch; its features are detached.
            max_layer: how many entries of the ``VGGFaceFeats`` output to use.
            cos_dist: use ``cos_loss`` instead of L1 as the per-layer distance.

        Returns:
            The summed distance (a tensor, or ``0.0`` if no layer is selected).
        """
        target = (target + 1) * 0.5
        input = (input + 1) * 0.5

        # preprocessing
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = input - self.mean
        target = target - self.mean
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)

        input_feats = self.vgg(input)
        target_feats = self.vgg(target)

        loss_func = cos_loss if cos_dist else torch.nn.functional.l1_loss
        # calc perceptual loss
        loss = 0.0
        for fi, ft in zip(input_feats[:max_layer], target_feats[:max_layer]):
            loss = loss + loss_func(fi, ft.detach())
        return loss


class PerceptualLoss(torch.nn.Module):
    """Weighted sum of the VGG16 and VGGFace perceptual losses.

    A sub-loss whose weight is not greater than ``eps`` is not instantiated,
    which avoids loading its network.

    Args:
        lambda_vggface: weight of the VGGFace term.
        lambda_vgg: weight of the VGG16 term.
        eps: weights at or below this value disable the corresponding term.
        cos_dist: use cosine distance for the VGGFace term. The VGG16 term
            always uses its L1 default.
    """

    def __init__(
        self,
        lambda_vggface: float = 0.025 / 0.15,
        lambda_vgg: float = 1,
        eps: float = 1e-8,
        cos_dist: bool = False,
    ):
        super().__init__()
        self.register_buffer("lambda_vggface", torch.tensor(lambda_vggface))
        self.register_buffer("lambda_vgg", torch.tensor(lambda_vgg))
        self.cos_dist = cos_dist

        if lambda_vgg > eps:
            self.vgg = VGGPerceptualLoss()
        if lambda_vggface > eps:
            self.vggface = VGGFacePerceptualLoss()

    def forward(self, input, target, eps=1e-8, use_vggface: bool = True, use_vgg: bool = True, max_vgg_layer=4):
        """Return the weighted sum of the enabled perceptual terms.

        Args:
            input, target: image batches forwarded unchanged to the sub-losses.
            eps: weights at or below this value skip the corresponding term.
            use_vggface, use_vgg: per-call switches for each term.
            max_vgg_layer: number of VGG16 stages to use (VGG16 term only).

        Returns:
            The weighted loss, or ``0.0`` if every term is skipped.
        """
        loss = 0.0
        if self.lambda_vgg > eps and use_vgg:
            loss = loss + self.lambda_vgg * self.vgg(input, target, max_layer=max_vgg_layer)
        if self.lambda_vggface > eps and use_vggface:
            loss = loss + self.lambda_vggface * self.vggface(input, target, cos_dist=self.cos_dist)
        return loss