ucalyptus's picture
simp
2d7efb8
raw
history blame
519 Bytes
import torch
import clip
class CLIPLoss(torch.nn.Module):
    """Image/text dissimilarity loss computed with a pretrained CLIP model.

    The loss is ``1 - logits / 100`` where ``logits`` is the first output of
    the CLIP model for an (image, text) pair, so a lower value means the
    image and the text agree more.
    """

    def __init__(self, opts):
        """Load CLIP ViT-B/32 and build the image resizing layers.

        Args:
            opts: Options object; only ``opts.stylegan_size`` (an int, the
                side length of the generator output) is read here.
        """
        super().__init__()
        # Prefer the GPU, but fall back to the CPU instead of failing at
        # construction time on hosts without CUDA.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = clip.load("ViT-B/32", device=device)
        # Upsampling by 7 and then average-pooling by ``stylegan_size // 32``
        # maps a ``stylegan_size`` x ``stylegan_size`` image to 224 x 224
        # whenever ``stylegan_size`` is a multiple of 32
        # (7 * s / (s / 32) == 224) -- presumably the input resolution
        # expected by CLIP ViT-B/32.
        self.upsample = torch.nn.Upsample(scale_factor=7)
        self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32)

    def forward(self, image, text):
        """Return the CLIP-based dissimilarity between ``image`` and ``text``.

        Args:
            image: Image batch tensor; assumed to be
                (batch, channels, stylegan_size, stylegan_size) and on the
                same device as the CLIP model -- TODO confirm with callers.
            text: Text input passed straight through to the CLIP model
                (presumably the output of ``clip.tokenize``).

        Returns:
            ``1 - logits / 100`` where ``logits`` is the first element
            returned by the CLIP model.
        """
        resized = self.avg_pool(self.upsample(image))
        # The division by 100 presumably undoes CLIP's logit scale so the
        # value is close to a cosine similarity -- verify against the CLIP
        # implementation in use.
        logits_per_image = self.model(resized, text)[0]
        return 1 - logits_per_image / 100