from functools import partial
from typing import Callable, Dict, List

import numpy as np
import torch
from torchmetrics.functional.multimodal import clip_score
from torchmetrics.image.inception import InceptionScore

SEED = 0
inception_score_fn = InceptionScore(normalize=True)
torch.manual_seed(SEED)
clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-base-patch16")


def _to_nchw(images: np.ndarray) -> torch.Tensor:
    """Wrap a channels-last (N, H, W, C) NumPy batch as a channels-first torch tensor.

    ``torch.from_numpy`` shares memory with *images* and ``permute`` returns a
    view, so no pixel data is copied here.
    """
    return torch.from_numpy(images).permute(0, 3, 1, 2)


def compute_main_metrics(images: np.ndarray, prompts: List[str]) -> Dict:
    """Compute the Inception Score and CLIP score for a batch of images.

    Args:
        images: batch laid out as (N, H, W, C). Assumed to hold floats in
            [0, 1], since ``InceptionScore`` is built with ``normalize=True``
            and the CLIP path rescales by 255 (TODO confirm with callers).
        prompts: text prompts paired with ``images`` for CLIP scoring.

    Returns:
        ``{"inception_score (⬆️)": {"mean": ..., "std": ...},
        "clip_score (⬆️)": ...}`` with every value rounded to 4 decimals.
    """
    # The module-level InceptionScore metric is stateful: ``update`` accumulates
    # into its internal buffer. Reset it after computing so each call scores only
    # its own images, not every batch passed in by earlier calls.
    try:
        inception_score_fn.update(_to_nchw(images))
        inception_mean, inception_std = inception_score_fn.compute()
    finally:
        inception_score_fn.reset()

    # CLIP scoring is fed uint8 pixels. Note that astype truncates the scaled
    # values rather than rounding them.
    images_int = (images * 255).astype("uint8")
    # Named ``clip_value`` so it does not shadow the imported ``clip_score``.
    clip_value = clip_score_fn(_to_nchw(images_int), prompts).detach()

    return {
        "inception_score (⬆️)": {
            "mean": round(float(inception_mean), 4),
            "std": round(float(inception_std), 4),
        },
        "clip_score (⬆️)": round(float(clip_value), 4),
    }


def compute_psnr_or_ssim(
    fn: Callable, images_dict: Dict, original_scheduler_name: str
) -> Dict:
    """Score each image batch in ``images_dict`` against a reference batch.

    Args:
        fn: metric called as ``fn(candidate, reference)`` on channels-first
            tensors. Its result must be accepted by ``float()``. Presumably a
            PSNR/SSIM function, per this function's name.
        images_dict: maps a name (e.g. a scheduler name) to an (N, H, W, C)
            NumPy image batch.
        original_scheduler_name: key of the reference batch in ``images_dict``.

    Returns:
        Mapping of every other key in ``images_dict`` to its score, rounded to
        4 decimals. The reference key itself is excluded.

    Raises:
        KeyError: if ``original_scheduler_name`` is not in ``images_dict``.
    """
    reference = _to_nchw(images_dict[original_scheduler_name])
    return {
        name: round(float(fn(_to_nchw(batch), reference)), 4)
        for name, batch in images_dict.items()
        if name != original_scheduler_name
    }