import torch

from .modules.lpips import LPIPS


def lpips(x: torch.Tensor,
          y: torch.Tensor,
          net_type: str = 'alex',
          version: str = '0.1'):
    r"""Compute the Learned Perceptual Image Patch Similarity (LPIPS)
    between two tensors.

    A new ``LPIPS`` module is constructed on every call, moved to the
    device of ``x``, and then applied to ``(x, y)``.

    Arguments:
        x, y (torch.Tensor): the input tensors to compare. Only ``x`` is
            consulted for the target device; ``y`` is passed through
            unchanged (presumably it must already live on the same
            device -- confirm against the ``LPIPS`` module).
        net_type (str): the network type used to compare the features:
            'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
        version (str): the version of LPIPS. Default: 0.1.

    Returns:
        Whatever the ``LPIPS`` module produces when called with
        ``(x, y)``.
    """
    # Read the device first so the metric network ends up alongside x.
    target_device = x.device

    metric = LPIPS(net_type, version)
    metric = metric.to(target_device)

    return metric(x, y)