Spaces:
Runtime error
Runtime error
""" | |
Code borrowed from https://gist.github.com/alper111/8233cdb0414b4cb5853f2f730ab95a49#file-vgg_perceptual_loss-py-L5 | |
""" | |
import torch | |
import torchvision | |
from models.vggface import VGGFaceFeats | |
def cos_loss(fi, ft): | |
return 1 - torch.nn.functional.cosine_similarity(fi, ft).mean() | |
class VGGPerceptualLoss(torch.nn.Module): | |
def __init__(self, resize=False): | |
super(VGGPerceptualLoss, self).__init__() | |
blocks = [] | |
blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval()) | |
blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval()) | |
blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval()) | |
blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval()) | |
for bl in blocks: | |
for p in bl: | |
p.requires_grad = False | |
self.blocks = torch.nn.ModuleList(blocks) | |
self.transform = torch.nn.functional.interpolate | |
self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1)) | |
self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)) | |
self.resize = resize | |
def forward(self, input, target, max_layer=4, cos_dist: bool = False): | |
target = (target + 1) * 0.5 | |
input = (input + 1) * 0.5 | |
if input.shape[1] != 3: | |
input = input.repeat(1, 3, 1, 1) | |
target = target.repeat(1, 3, 1, 1) | |
input = (input-self.mean) / self.std | |
target = (target-self.mean) / self.std | |
if self.resize: | |
input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False) | |
target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False) | |
x = input | |
y = target | |
loss = 0.0 | |
loss_func = cos_loss if cos_dist else torch.nn.functional.l1_loss | |
for bi, block in enumerate(self.blocks[:max_layer]): | |
x = block(x) | |
y = block(y) | |
loss += loss_func(x, y.detach()) | |
return loss | |
class VGGFacePerceptualLoss(torch.nn.Module): | |
def __init__(self, weight_path: str = "checkpoint/vgg_face_dag.pt", resize: bool = False): | |
super().__init__() | |
self.vgg = VGGFaceFeats() | |
self.vgg.load_state_dict(torch.load(weight_path)) | |
mean = torch.tensor(self.vgg.meta["mean"]).view(1, 3, 1, 1) / 255.0 | |
self.register_buffer("mean", mean) | |
self.transform = torch.nn.functional.interpolate | |
self.resize = resize | |
def forward(self, input, target, max_layer: int = 4, cos_dist: bool = False): | |
target = (target + 1) * 0.5 | |
input = (input + 1) * 0.5 | |
# preprocessing | |
if input.shape[1] != 3: | |
input = input.repeat(1, 3, 1, 1) | |
target = target.repeat(1, 3, 1, 1) | |
input = input - self.mean | |
target = target - self.mean | |
if self.resize: | |
input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False) | |
target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False) | |
input_feats = self.vgg(input) | |
target_feats = self.vgg(target) | |
loss_func = cos_loss if cos_dist else torch.nn.functional.l1_loss | |
# calc perceptual loss | |
loss = 0.0 | |
for fi, ft in zip(input_feats[:max_layer], target_feats[:max_layer]): | |
loss = loss + loss_func(fi, ft.detach()) | |
return loss | |
class PerceptualLoss(torch.nn.Module): | |
def __init__( | |
self, lambda_vggface: float = 0.025 / 0.15, lambda_vgg: float = 1, eps: float = 1e-8, cos_dist: bool = False | |
): | |
super().__init__() | |
self.register_buffer("lambda_vggface", torch.tensor(lambda_vggface)) | |
self.register_buffer("lambda_vgg", torch.tensor(lambda_vgg)) | |
self.cos_dist = cos_dist | |
if lambda_vgg > eps: | |
self.vgg = VGGPerceptualLoss() | |
if lambda_vggface > eps: | |
self.vggface = VGGFacePerceptualLoss() | |
def forward(self, input, target, eps=1e-8, use_vggface: bool = True, use_vgg=True, max_vgg_layer=4): | |
loss = 0.0 | |
if self.lambda_vgg > eps and use_vgg: | |
loss = loss + self.lambda_vgg * self.vgg(input, target, max_layer=max_vgg_layer) | |
if self.lambda_vggface > eps and use_vggface: | |
loss = loss + self.lambda_vggface * self.vggface(input, target, cos_dist=self.cos_dist) | |
return loss | |