# starvector/metrics/compute_clip_score.py
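"""CLIPScore metric for generated images.

Scores generated raster images against their captions with CLIP
(openai/clip-vit-base-patch32) via torchmetrics, batching the image-caption
pairs through a DataLoader and reporting the average and per-sample scores.
"""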
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchmetrics.multimodal.clip_score import CLIPScore
from torchmetrics.functional.multimodal.clip_score import _clip_score_update
from tqdm import tqdm

from starvector.metrics.base_metric import BaseMetric


class CLIPScoreCalculator(BaseMetric):
    def __init__(self):
        super().__init__()
        self.class_name = self.__class__.__name__
        # CLIP ViT-B/32 backbone used to score image-caption similarity
        self.clip_score = CLIPScore(model_name_or_path="openai/clip-vit-base-patch32")
        self.clip_score.to('cuda')

    def CLIP_Score(self, images, captions):
        # _clip_score_update returns (per-sample scores, sample count)
        all_scores = _clip_score_update(images, captions, self.clip_score.model, self.clip_score.processor)
        return all_scores

    def collate_fn(self, batch):
        # Keep (image, caption) pairs aligned and convert PIL images to tensors;
        # torchmetrics expects the captions as a list, not a tuple
        gen_imgs, captions = zip(*batch)
        tensor_gen_imgs = [transforms.ToTensor()(img) for img in gen_imgs]
        return tensor_gen_imgs, list(captions)

    def calculate_score(self, batch, batch_size=512, update=True):
        gen_images = batch['gen_im']
        captions = batch['caption']

        # Batch the (image, caption) pairs with the custom collate function
        data_loader = DataLoader(
            list(zip(gen_images, captions)),
            collate_fn=self.collate_fn,
            batch_size=batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
        )

        all_scores = []
        for images, batch_captions in tqdm(data_loader):
            # ToTensor scales pixels to [0, 1]; rescale to [0, 255] for the CLIP processor
            images = [img.to('cuda', non_blocking=True) * 255 for img in images]
            list_scores = self.CLIP_Score(images, batch_captions)[0].detach().cpu().tolist()
            all_scores.extend(list_scores)

        if not all_scores:
            print("No valid scores found for metric calculation.")
            return float("nan"), []

        avg_score = sum(all_scores) / len(all_scores)
        if update:
            self.meter.update(avg_score, len(all_scores))
            return self.meter.avg, all_scores
        return avg_score, all_scores


if __name__ == '__main__':
    import multiprocessing

    # CUDA tensors in DataLoader workers require the 'spawn' start method
    multiprocessing.set_start_method('spawn')
    # Rest of your code...
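
    # Hypothetical usage sketch (not part of the original file): assumes CUDA is
    # available and that generated PIL images and their captions are supplied
    # under the 'gen_im' and 'caption' keys that calculate_score expects.
    from PIL import Image

    example_batch = {
        'gen_im': [Image.new("RGB", (224, 224), "white") for _ in range(4)],
        'caption': ["a plain white square"] * 4,
    }
    calculator = CLIPScoreCalculator()
    avg_score, per_sample_scores = calculator.calculate_score(example_batch, batch_size=2, update=False)
    print(f"Average CLIPScore: {avg_score:.4f}")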