import pickle
import torch
import torchvision
from pathlib import Path
from dp2 import utils
import tops
try:
    import clip
except ImportError:
    print("Could not import clip.")
from torch_fidelity.metric_fid import fid_features_to_statistics, fid_statistics_to_metric

# Lazily initialized CLIP visual encoder and its preprocessing pipeline.
clip_model = None
clip_preprocess = None


@torch.no_grad()
def compute_fid_clip(
        dataloader, generator,
        cache_directory,
        data_len=None,
        **kwargs
        ) -> dict:
    """
    FID-CLIP following the description in "The Role of ImageNet Classes in
    Frechet Inception Distance", Tuomas Kynkäänniemi et al.

    Args:
        dataloader: Yields batches with an "img" key holding normalized real images.
        generator: Model returning a dict with an "img" key of generated images.
        cache_directory: Directory where statistics of real-image features are cached.
        data_len (int, optional): Total number of samples. Defaults to
            len(dataloader) * dataloader.batch_size.
    """
    global clip_model, clip_preprocess
    if clip_model is None:
        clip_model, preprocess = clip.load("ViT-B/32", device="cpu")
        # Reuse the normalization constants from the CLIP preprocessing pipeline.
        normalize_fn = preprocess.transforms[-1]
        img_mean = normalize_fn.mean
        img_std = normalize_fn.std
        # Only the visual encoder is needed for feature extraction.
        clip_model = tops.to_cuda(clip_model.visual)
        clip_preprocess = tops.to_cuda(torch.nn.Sequential(
            torchvision.transforms.Resize(
                (224, 224), interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
            torchvision.transforms.Normalize(img_mean, img_std)
        ))
    cache_directory = Path(cache_directory)
    if data_len is None:
        data_len = len(dataloader) * dataloader.batch_size
    fid_cache_path = cache_directory.joinpath("fid_stats_clip.pkl")
    has_fid_cache = fid_cache_path.is_file()
    # Real-image features are only extracted if no cached statistics exist.
    if not has_fid_cache:
        fid_features_real = torch.zeros(data_len, 512, dtype=torch.float32, device=tops.get_device())
    fid_features_fake = torch.zeros(data_len, 512, dtype=torch.float32, device=tops.get_device())
    eidx = 0
    n_samples_seen = 0
    for batch in utils.tqdm_(iter(dataloader), desc="Computing FID CLIP."):
        sidx = eidx
        eidx = sidx + batch["img"].shape[0]
        n_samples_seen += batch["img"].shape[0]
        with torch.cuda.amp.autocast(tops.AMP()):
            fakes = generator(**batch)["img"]
            real_data = batch["img"]
            fakes = utils.denormalize_img(fakes)
            real_data = utils.denormalize_img(real_data)
            if not has_fid_cache:
                real_data = clip_preprocess(real_data)
                fid_features_real[sidx:eidx] = clip_model(real_data)
            fakes = clip_preprocess(fakes)
            fid_features_fake[sidx:eidx] = clip_model(fakes)
    # Trim to the number of samples actually seen and gather across ranks.
    fid_features_fake = fid_features_fake[:n_samples_seen]
    fid_features_fake = tops.all_gather_uneven(fid_features_fake).cpu()
    if has_fid_cache:
        if tops.rank() == 0:
            with open(fid_cache_path, "rb") as fp:
                fid_stat_real = pickle.load(fp)
    else:
        fid_features_real = fid_features_real[:n_samples_seen]
        fid_features_real = tops.all_gather_uneven(fid_features_real).cpu()
        assert fid_features_real.shape == fid_features_fake.shape
        if tops.rank() == 0:
            # Cache real-image statistics so subsequent runs skip real-feature extraction.
            fid_stat_real = fid_features_to_statistics(fid_features_real)
            cache_directory.mkdir(exist_ok=True, parents=True)
            with open(fid_cache_path, "wb") as fp:
                pickle.dump(fid_stat_real, fp)
    if tops.rank() == 0:
        print("Starting calculation of fid from features of shape:", fid_features_fake.shape)
        fid_stat_fake = fid_features_to_statistics(fid_features_fake)
        fid_ = fid_statistics_to_metric(fid_stat_real, fid_stat_fake, verbose=False)["frechet_inception_distance"]
        return dict(fid_clip=fid_)
    return dict(fid_clip=-1)
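

# Minimal usage sketch (not part of the original module): how compute_fid_clip
# might be called after training. The names below (val_dataloader,
# trained_generator, the import path, and the cache directory) are hypothetical
# placeholders; any dataloader yielding batches with an "img" key and any
# generator returning a dict with an "img" key should fit the interface above.
#
#   from dp2.metrics.fid_clip import compute_fid_clip  # assumed module path
#
#   metrics = compute_fid_clip(
#       dataloader=val_dataloader,                # hypothetical validation loader
#       generator=trained_generator,              # hypothetical trained generator
#       cache_directory="outputs/metric_cache",   # hypothetical cache directory
#   )
#   print("FID-CLIP:", metrics["fid_clip"])       # -1 on non-zero ranks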