import torch | |
import clip | |
class CLIPLoss(torch.nn.Module):
    """Loss module that scores an image batch against tokenized text with CLIP.

    The image is resized by a fixed 7x upsample followed by average pooling
    with a window of ``opts.stylegan_size // 32``; the result is passed to the
    CLIP model together with ``text`` and the first element of the model's
    output is mapped to ``1 - output / 100``.

    Args:
        opts: Options object exposing an integer ``stylegan_size`` attribute
            (the side length of the generator output, in pixels).
    """

    def __init__(self, opts):
        super().__init__()
        # Use the GPU when one is usable, otherwise fall back to the CPU so
        # constructing the loss does not fail on CUDA-less hosts.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = clip.load("ViT-B/32", device=device)
        # Net resize factor is 7 / (stylegan_size // 32); for a size that is a
        # multiple of 32 this yields a 224-pixel side (e.g. 1024 * 7 / 32).
        # NOTE(review): 224 is presumably the input resolution ViT-B/32
        # expects -- confirm against the CLIP documentation.
        self.upsample = torch.nn.Upsample(scale_factor=7)
        self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32)

    def forward(self, image, text):
        """Return ``1 - model(image, text)[0] / 100`` for the resized image.

        Args:
            image: Image batch to score; it is fed through ``Upsample`` and
                ``AvgPool2d`` first, so it is assumed to be a 4-D
                (batch, channels, height, width) tensor -- TODO confirm
                against callers.
            text: Text input forwarded unchanged to the CLIP model
                (presumably the output of ``clip.tokenize`` -- verify against
                callers).

        Returns:
            The first element of the CLIP model output, divided by 100 and
            subtracted from 1.
        """
        image = self.avg_pool(self.upsample(image))
        # NOTE(review): element 0 is presumably CLIP's per-image logits and
        # the division by 100 presumably undoes its logit scale -- confirm.
        similarity = 1 - self.model(image, text)[0] / 100
        return similarity