Spaces:
Running
Running
Update landmarkdiff/fid.py to v0.3.2
Browse files- landmarkdiff/fid.py +81 -77
landmarkdiff/fid.py
CHANGED
|
@@ -24,10 +24,10 @@ try:
|
|
| 24 |
import torch
|
| 25 |
import torch.nn as nn
|
| 26 |
from torch.utils.data import DataLoader, Dataset
|
| 27 |
-
|
| 28 |
HAS_TORCH = True
|
| 29 |
except ImportError:
|
| 30 |
HAS_TORCH = False
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
def _load_inception_v3() -> Any:
|
|
@@ -42,95 +42,99 @@ def _load_inception_v3() -> Any:
|
|
| 42 |
return model
|
| 43 |
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
"""
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
img = cv2.resize(img, (self.image_size, self.image_size))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
img = cv2.resize(img, (self.image_size, self.image_size))
|
| 93 |
-
if img.shape[2] == 3:
|
| 94 |
-
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 95 |
-
t = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1)
|
| 96 |
-
t = _imagenet_normalize(t)
|
| 97 |
-
return t
|
| 98 |
-
|
| 99 |
-
def _imagenet_normalize(t: Any) -> Any:
|
| 100 |
-
"""Apply ImageNet normalization."""
|
| 101 |
-
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
|
| 102 |
-
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
|
| 103 |
-
return (t - mean) / std
|
| 104 |
-
|
| 105 |
-
@torch.no_grad()
|
| 106 |
-
def _extract_features(
|
| 107 |
-
model: Any,
|
| 108 |
-
dataloader: Any,
|
| 109 |
-
device: Any,
|
| 110 |
-
) -> np.ndarray:
|
| 111 |
-
"""Extract InceptionV3 pool3 features from a dataloader."""
|
| 112 |
-
features = []
|
| 113 |
for batch in dataloader:
|
| 114 |
batch = batch.to(device)
|
| 115 |
feat = model(batch)
|
| 116 |
if isinstance(feat, tuple):
|
| 117 |
feat = feat[0]
|
| 118 |
features.append(feat.cpu().numpy())
|
| 119 |
-
|
| 120 |
|
| 121 |
|
| 122 |
def _compute_statistics(features: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
| 123 |
"""Compute mean and covariance of feature vectors."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
mu = np.mean(features, axis=0)
|
| 125 |
sigma = np.cov(features, rowvar=False)
|
| 126 |
return mu, sigma
|
| 127 |
|
| 128 |
|
| 129 |
def _calculate_fid(
|
| 130 |
-
mu1: np.ndarray,
|
| 131 |
-
|
| 132 |
-
mu2: np.ndarray,
|
| 133 |
-
sigma2: np.ndarray,
|
| 134 |
) -> float:
|
| 135 |
"""Calculate FID given two sets of statistics.
|
| 136 |
|
|
@@ -146,7 +150,7 @@ def _calculate_fid(
|
|
| 146 |
covmean = covmean.real
|
| 147 |
|
| 148 |
fid = diff @ diff + np.trace(sigma1 + sigma2 - 2 * covmean)
|
| 149 |
-
return float(fid)
|
| 150 |
|
| 151 |
|
| 152 |
def compute_fid_from_dirs(
|
|
@@ -183,10 +187,10 @@ def compute_fid_from_dirs(
|
|
| 183 |
if len(real_ds) == 0 or len(gen_ds) == 0:
|
| 184 |
raise ValueError("Need at least 1 image in each directory")
|
| 185 |
|
| 186 |
-
real_loader = DataLoader(
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
|
| 191 |
real_features = _extract_features(model, real_loader, dev)
|
| 192 |
gen_features = _extract_features(model, gen_loader, dev)
|
|
|
|
| 24 |
import torch
|
| 25 |
import torch.nn as nn
|
| 26 |
from torch.utils.data import DataLoader, Dataset
|
|
|
|
| 27 |
HAS_TORCH = True
|
| 28 |
except ImportError:
|
| 29 |
HAS_TORCH = False
|
| 30 |
+
Dataset = object # type: ignore[misc,assignment]
|
| 31 |
|
| 32 |
|
| 33 |
def _load_inception_v3() -> Any:
|
|
|
|
| 42 |
return model
|
| 43 |
|
| 44 |
|
| 45 |
+
class ImageFolderDataset(Dataset):
    """Dataset yielding an ImageNet-normalized CHW tensor per image file.

    Scans *directory* once at construction time for files with a known
    image extension; images are decoded lazily in ``__getitem__``.
    """

    def __init__(self, directory: str | Path, image_size: int = 299):
        self.directory = Path(directory)
        allowed = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
        candidates = [
            p for p in self.directory.iterdir()
            if p.suffix.lower() in allowed and p.is_file()
        ]
        # Sort for a deterministic ordering across runs.
        self.files = sorted(candidates)
        self.image_size = image_size

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, idx: int) -> Any:
        import cv2

        raw = cv2.imread(str(self.files[idx]))
        if raw is None:
            # Unreadable file: fall back to an all-zero tensor instead of raising.
            return torch.zeros(3, self.image_size, self.image_size)
        resized = cv2.resize(raw, (self.image_size, self.image_size))
        rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
        # Scale to [0, 1], move channels first, then apply ImageNet statistics.
        tensor = torch.from_numpy(rgb.astype(np.float32) / 255.0).permute(2, 0, 1)
        return _imagenet_normalize(tensor)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class NumpyArrayDataset(Dataset):
    """Dataset over in-memory numpy images, normalized for InceptionV3.

    Accepts grayscale (H, W), BGR (H, W, 3) or BGRA (H, W, 4) arrays and
    converts each to an ImageNet-normalized (3, S, S) float tensor.
    """

    def __init__(self, images: list[np.ndarray], image_size: int = 299):
        self.images = images
        self.image_size = image_size

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> Any:
        import cv2

        frame = self.images[idx]
        target = (self.image_size, self.image_size)
        if frame.shape[:2] != target:
            frame = cv2.resize(frame, target)
        # Coerce whatever channel layout we got into 3-channel RGB.
        if frame.ndim == 2:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        elif frame.shape[2] == 4:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGRA2RGB)
        elif frame.shape[2] == 3:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        chw = torch.from_numpy(frame.astype(np.float32) / 255.0).permute(2, 0, 1)
        return _imagenet_normalize(chw)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _imagenet_normalize(t: torch.Tensor) -> torch.Tensor:
|
| 101 |
+
"""Apply ImageNet normalization."""
|
| 102 |
+
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
|
| 103 |
+
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
|
| 104 |
+
return (t - mean) / std
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _extract_features(
|
| 108 |
+
model: nn.Module,
|
| 109 |
+
dataloader: DataLoader,
|
| 110 |
+
device: torch.device,
|
| 111 |
+
) -> np.ndarray:
|
| 112 |
+
"""Extract InceptionV3 pool3 features from a dataloader."""
|
| 113 |
+
features = []
|
| 114 |
+
with torch.no_grad():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
for batch in dataloader:
|
| 116 |
batch = batch.to(device)
|
| 117 |
feat = model(batch)
|
| 118 |
if isinstance(feat, tuple):
|
| 119 |
feat = feat[0]
|
| 120 |
features.append(feat.cpu().numpy())
|
| 121 |
+
return np.concatenate(features, axis=0)
|
| 122 |
|
| 123 |
|
| 124 |
def _compute_statistics(features: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
| 125 |
"""Compute mean and covariance of feature vectors."""
|
| 126 |
+
if features.shape[0] < 2:
|
| 127 |
+
raise ValueError(
|
| 128 |
+
f"FID requires at least 2 images, got {features.shape[0]}"
|
| 129 |
+
)
|
| 130 |
mu = np.mean(features, axis=0)
|
| 131 |
sigma = np.cov(features, rowvar=False)
|
| 132 |
return mu, sigma
|
| 133 |
|
| 134 |
|
| 135 |
def _calculate_fid(
|
| 136 |
+
mu1: np.ndarray, sigma1: np.ndarray,
|
| 137 |
+
mu2: np.ndarray, sigma2: np.ndarray,
|
|
|
|
|
|
|
| 138 |
) -> float:
|
| 139 |
"""Calculate FID given two sets of statistics.
|
| 140 |
|
|
|
|
| 150 |
covmean = covmean.real
|
| 151 |
|
| 152 |
fid = diff @ diff + np.trace(sigma1 + sigma2 - 2 * covmean)
|
| 153 |
+
return float(max(fid, 0.0))
|
| 154 |
|
| 155 |
|
| 156 |
def compute_fid_from_dirs(
|
|
|
|
| 187 |
if len(real_ds) == 0 or len(gen_ds) == 0:
|
| 188 |
raise ValueError("Need at least 1 image in each directory")
|
| 189 |
|
| 190 |
+
real_loader = DataLoader(real_ds, batch_size=batch_size,
|
| 191 |
+
num_workers=num_workers, pin_memory=True)
|
| 192 |
+
gen_loader = DataLoader(gen_ds, batch_size=batch_size,
|
| 193 |
+
num_workers=num_workers, pin_memory=True)
|
| 194 |
|
| 195 |
real_features = _extract_features(model, real_loader, dev)
|
| 196 |
gen_features = _extract_features(model, gen_loader, dev)
|