Spaces:
Paused
Paused
from torchvision.transforms import ToTensor | |
import torch.nn.functional as F | |
from starvector.metrics.base_metric import BaseMetric | |
import torch | |
from torchmetrics.multimodal.clip_score import CLIPScore | |
from torch.utils.data import DataLoader | |
from tqdm import tqdm | |
import torchvision.transforms as transforms | |
from torchmetrics.functional.multimodal.clip_score import _clip_score_update | |
class CLIPScoreCalculator(BaseMetric): | |
def __init__(self): | |
super().__init__() | |
self.class_name = self.__class__.__name__ | |
self.clip_score = CLIPScore(model_name_or_path="openai/clip-vit-base-patch32") | |
self.clip_score.to('cuda') | |
def CLIP_Score(self, images, captions): | |
all_scores = _clip_score_update(images, captions, self.clip_score.model, self.clip_score.processor) | |
return all_scores | |
def collate_fn(self, batch): | |
gen_imgs, captions = zip(*batch) | |
tensor_gen_imgs = [transforms.ToTensor()(img) for img in gen_imgs] | |
return tensor_gen_imgs, captions | |
def calculate_score(self, batch, batch_size=512, update=True): | |
gen_images = batch['gen_im'] | |
captions = batch['caption'] | |
# Create DataLoader with custom collate function | |
data_loader = DataLoader(list(zip(gen_images, captions)), collate_fn=self.collate_fn, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) | |
all_scores = [] | |
for batch_eval in tqdm(data_loader): | |
images, captions = batch_eval | |
images = [img.to('cuda', non_blocking=True) * 255 for img in images] | |
list_scores = self.CLIP_Score(images, captions)[0].detach().cpu().tolist() | |
all_scores.extend(list_scores) | |
if not all_scores: | |
print("No valid scores found for metric calculation.") | |
return float("nan"), [] | |
avg_score = sum(all_scores) / len(all_scores) | |
if update: | |
self.meter.update(avg_score, len(all_scores)) | |
return self.meter.avg, all_scores | |
else: | |
return avg_score, all_scores | |
if __name__ == '__main__': | |
import multiprocessing | |
multiprocessing.set_start_method('spawn') | |
# Rest of your code... |