# NOTE: "Spaces: Running / Running" header residue from the hosting page removed
# (it was not Python and broke the file); no program content was lost.
from typing import * | |
import random | |
import torch | |
from torch import Tensor | |
from torchmetrics.metric import Metric | |
from torchmetrics.utilities import rank_zero_info | |
import clip | |
import torch.nn.functional as F | |
from tqdm import tqdm | |
from PIL import Image | |
def read_image(imgid):
    """Open the image identified by *imgid* and return it as an RGB PIL image.

    *imgid* may be either a direct path or a bare id resolved under
    ``data_en/images/``. Exactly one of the two candidate locations must
    exist.

    Raises:
        ValueError: both candidate paths exist (ambiguous id).
        FileNotFoundError: neither candidate path exists.
    """
    from pathlib import Path
    vanilla = Path(imgid)
    fixed = Path(f"data_en/images/{imgid}")
    # Explicit exceptions instead of the original XOR assert: asserts are
    # stripped under ``python -O`` and carried no diagnostic message.
    if vanilla.exists() and fixed.exists():
        raise ValueError(f"ambiguous image id {imgid!r}: both {vanilla} and {fixed} exist")
    if not vanilla.exists() and not fixed.exists():
        raise FileNotFoundError(f"image {imgid!r} found at neither {vanilla} nor {fixed}")
    path = vanilla if vanilla.exists() else fixed
    return Image.open(path).convert("RGB")
class MID():
    """Mutual Information Divergence scorer backed by CLIP ViT-B/32 features.

    Candidates, references and images are encoded with CLIP in batches of 32
    and each batch is scored with ``compute_pmi``.
    """

    def __init__(self, device="cuda"):
        # Load CLIP once; ``clip_preprocess`` turns a PIL image into a model input.
        self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=device)
        self.device = device

    def batchify(self, targets, batch_size):
        """Split *targets* into consecutive chunks of at most *batch_size* items."""
        return [targets[i:i + batch_size] for i in range(0, len(targets), batch_size)]

    def __call__(self, mt_list, refs_list, img_list, no_ref=False):
        """Score each batch of candidates against its references and images.

        Args:
            mt_list: candidate captions, one string per sample.
            refs_list: reference captions, a list of strings per sample.
            img_list: image ids resolvable by ``read_image``.
            no_ref: accepted for interface compatibility; currently unused.

        Returns:
            A list with one ``compute_pmi`` result per batch of 32 samples.
        """
        batch_size = 32
        mt_list, refs_list, img_list = [
            self.batchify(seq, batch_size) for seq in (mt_list, refs_list, img_list)
        ]
        assert len(mt_list) == len(refs_list) == len(img_list)
        scores = []
        for mt, refs, imgs in (pbar := tqdm(zip(mt_list, refs_list, img_list),
                                            total=len(mt_list))):
            pbar.set_description("MID")
            images = [read_image(imgid) for imgid in imgs]
            # Flatten the per-sample reference lists into a single token batch.
            ref_tokens = torch.cat(
                [clip.tokenize(ref, truncate=True).to(self.device)
                 for ref_list in refs for ref in ref_list],
                dim=0,
            )
            mt_tokens = clip.tokenize(list(mt), truncate=True).to(self.device)
            image_batch = torch.cat(
                [self.clip_preprocess(img).unsqueeze(0) for img in images],
                dim=0,
            ).to(self.device)
            # Pure inference: skip autograd graph construction.
            with torch.no_grad():
                img_feats = self.clip.encode_image(image_batch)
                mt_feats = self.clip.encode_text(mt_tokens)
                ref_feats = self.clip.encode_text(ref_tokens)
            # BUG FIX: the original discarded compute_pmi's result and always
            # returned an empty list; collect the per-batch score.
            scores.append(compute_pmi(img_feats, ref_feats, mt_feats))
        return scores
def log_det(X):
    """Log-determinant of *X* computed from its singular values.

    Valid as a log-determinant when X is symmetric positive definite
    (e.g. the covariance matrices used by ``compute_pmi``).
    """
    singular_values = X.svd().S
    return singular_values.log().sum()
def robust_inv(x, eps=0):
    """Inverse of square matrix *x*, optionally ridge-regularized by eps * I."""
    identity = torch.eye(x.shape[0], device=x.device)
    return torch.inverse(x + eps * identity)
def exp_smd(a, b, reduction=True):
    """Squared-Mahalanobis-distance helper.

    With ``reduction=True``, *b* must be square and the result is
    ``tr(a^{-1} b)``; otherwise each row of *b* is treated as a deviation
    vector and the per-row quadratic forms ``diag(b a^{-1} b^T)`` are
    returned.
    """
    a_inv = robust_inv(a)
    if not reduction:
        return (b @ a_inv @ b.t()).diag()
    assert b.shape[0] == b.shape[1]
    return (a_inv @ b).trace()
def compute_pmi(x: Tensor, y: Tensor, x0: Tensor, limit: int = 30000,
                reduction: bool = True, full: bool = False) -> Tensor:
    r"""
    A numerical stable version of the MID score.

    Estimates Gaussian covariances of the real-image features, the text
    features, and their concatenation, then combines a mutual-information
    term (log-determinants) with squared-Mahalanobis-distance terms for the
    fake samples.

    Args:
        x (Tensor): features for real samples
        y (Tensor): features for text samples
        x0 (Tensor): features for fake samples
        limit (int): limit the number of samples
        reduction (bool): returns the expectation of PMI if true else sample-wise results
        full (bool): use full samples from real images

    Returns:
        Scalar value of the mutual information divergence between the sets
        (or a per-sample tensor when ``reduction`` is false).

    Note:
        All inputs must be float64 (asserted below), and the number of real
        samples must be at least the feature dimension so the covariance
        matrices are invertible.
    """
    N = x.shape[0]
    excess = N - limit
    if 0 < excess:
        # Cap the sample count at `limit`; with full=True the real features x
        # are kept whole while y and x0 are truncated.
        if not full:
            x = x[:-excess]
        y = y[:-excess]
        x0 = x0[:-excess]
        N = x.shape[0]
    M = x0.shape[0]
    # Covariances below are D x D; fewer samples than dims makes them singular.
    assert N >= x.shape[1], "not full rank for matrix inversion!"
    if x.shape[0] < 30000:
        rank_zero_info("if it underperforms, please consider to use "
                       "the epsilon of 5e-4 or something else.")
    # Joint (image, text) features for real and fake samples.
    z = torch.cat([x, y], dim=-1)
    z0 = torch.cat([x0, y[:x0.shape[0]]], dim=-1)
    x_mean = x.mean(dim=0, keepdim=True)
    y_mean = y.mean(dim=0, keepdim=True)
    z_mean = torch.cat([x_mean, y_mean], dim=-1)
    x0_mean = x0.mean(dim=0, keepdim=True)  # NOTE(review): computed but unused below
    z0_mean = z0.mean(dim=0, keepdim=True)  # NOTE(review): computed but unused below
    # Sample covariances (unbiased, N-1 / M-1 denominators).
    X = (x - x_mean).t() @ (x - x_mean) / (N - 1)
    Y = (y - y_mean).t() @ (y - y_mean) / (N - 1)
    Z = (z - z_mean).t() @ (z - z_mean) / (N - 1)
    # Fake-sample scatter is centered on the REAL means, not the fake means.
    X0 = (x0 - x_mean).t() @ (x0 - x_mean) / (M - 1)  # use the reference mean
    Z0 = (z0 - z_mean).t() @ (z0 - z_mean) / (M - 1)  # use the reference mean
    alternative_comp = False
    # notice that it may have numerical unstability. we don't use this.
    if alternative_comp:
        def factorized_cov(x, m):
            # Cov via E[x^T x] - N m^T m; algebraically equal but less stable.
            N = x.shape[0]
            return (x.t() @ x - N * m.t() @ m) / (N - 1)
        X0 = factorized_cov(x0, x_mean)
        Z0 = factorized_cov(z0, z_mean)
    # assert double precision
    for _ in [X, Y, Z, X0, Z0]:
        assert _.dtype == torch.float64
    # Expectation of PMI: 1/2 (log|X| + log|Y| - log|Z|), the Gaussian MI.
    mi = (log_det(X) + log_det(Y) - log_det(Z)) / 2
    rank_zero_info(f"MI of real images: {mi:.4f}")
    # Squared Mahalanobis Distance terms
    if reduction:
        # exp_smd(Y, Y) is tr(I) = dim(Y); presumably a calibration constant
        # from the derivation — TODO confirm against the MID paper.
        smd = (exp_smd(X, X0) + exp_smd(Y, Y) - exp_smd(Z, Z0)) / 2
    else:
        # Per-sample quadratic forms on the (real-mean-centered) deviations.
        smd = (exp_smd(X, x0 - x_mean, False) + exp_smd(Y, y - y_mean, False)
               - exp_smd(Z, z0 - z_mean, False)) / 2
        mi = mi.unsqueeze(0)  # for broadcasting
    return mi + smd