dreamlessx committed on
Commit
28dc803
·
verified ·
1 Parent(s): 5be749f

Update landmarkdiff/fid.py to v0.3.2

Browse files
Files changed (1) hide show
  1. landmarkdiff/fid.py +81 -77
landmarkdiff/fid.py CHANGED
@@ -24,10 +24,10 @@ try:
24
  import torch
25
  import torch.nn as nn
26
  from torch.utils.data import DataLoader, Dataset
27
-
28
  HAS_TORCH = True
29
  except ImportError:
30
  HAS_TORCH = False
 
31
 
32
 
33
  def _load_inception_v3() -> Any:
@@ -42,95 +42,99 @@ def _load_inception_v3() -> Any:
42
  return model
43
 
44
 
45
- # Guard torch-dependent class and function definitions so the module
46
- # can be imported safely when torch is not installed.
47
- if HAS_TORCH:
48
-
49
- class ImageFolderDataset(Dataset): # type: ignore[misc]
50
- """Simple dataset that loads images from a directory."""
51
-
52
- def __init__(self, directory: str | Path, image_size: int = 299):
53
- self.directory = Path(directory)
54
- exts = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
55
- self.files = sorted(
56
- f for f in self.directory.iterdir() if f.suffix.lower() in exts and f.is_file()
57
- )
58
- self.image_size = image_size
59
-
60
- def __len__(self) -> int:
61
- return len(self.files)
62
-
63
- def __getitem__(self, idx: int) -> Any:
64
- import cv2
65
-
66
- img = cv2.imread(str(self.files[idx]))
67
- if img is None:
68
- # Return zeros if image can't be loaded
69
- return torch.zeros(3, self.image_size, self.image_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  img = cv2.resize(img, (self.image_size, self.image_size))
 
 
 
 
 
71
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
72
- # Normalize to [0, 1] then ImageNet normalize
73
- t = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1)
74
- t = _imagenet_normalize(t)
75
- return t
76
-
77
- class NumpyArrayDataset(Dataset): # type: ignore[misc]
78
- """Dataset wrapping a list of numpy arrays."""
79
-
80
- def __init__(self, images: list[np.ndarray], image_size: int = 299):
81
- self.images = images
82
- self.image_size = image_size
83
-
84
- def __len__(self) -> int:
85
- return len(self.images)
86
-
87
- def __getitem__(self, idx: int) -> Any:
88
- import cv2
89
-
90
- img = self.images[idx]
91
- if img.shape[:2] != (self.image_size, self.image_size):
92
- img = cv2.resize(img, (self.image_size, self.image_size))
93
- if img.shape[2] == 3:
94
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
95
- t = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1)
96
- t = _imagenet_normalize(t)
97
- return t
98
-
99
- def _imagenet_normalize(t: Any) -> Any:
100
- """Apply ImageNet normalization."""
101
- mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
102
- std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
103
- return (t - mean) / std
104
-
105
- @torch.no_grad()
106
- def _extract_features(
107
- model: Any,
108
- dataloader: Any,
109
- device: Any,
110
- ) -> np.ndarray:
111
- """Extract InceptionV3 pool3 features from a dataloader."""
112
- features = []
113
  for batch in dataloader:
114
  batch = batch.to(device)
115
  feat = model(batch)
116
  if isinstance(feat, tuple):
117
  feat = feat[0]
118
  features.append(feat.cpu().numpy())
119
- return np.concatenate(features, axis=0)
120
 
121
 
122
  def _compute_statistics(features: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
123
  """Compute mean and covariance of feature vectors."""
 
 
 
 
124
  mu = np.mean(features, axis=0)
125
  sigma = np.cov(features, rowvar=False)
126
  return mu, sigma
127
 
128
 
129
  def _calculate_fid(
130
- mu1: np.ndarray,
131
- sigma1: np.ndarray,
132
- mu2: np.ndarray,
133
- sigma2: np.ndarray,
134
  ) -> float:
135
  """Calculate FID given two sets of statistics.
136
 
@@ -146,7 +150,7 @@ def _calculate_fid(
146
  covmean = covmean.real
147
 
148
  fid = diff @ diff + np.trace(sigma1 + sigma2 - 2 * covmean)
149
- return float(fid)
150
 
151
 
152
  def compute_fid_from_dirs(
@@ -183,10 +187,10 @@ def compute_fid_from_dirs(
183
  if len(real_ds) == 0 or len(gen_ds) == 0:
184
  raise ValueError("Need at least 1 image in each directory")
185
 
186
- real_loader = DataLoader(
187
- real_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True
188
- )
189
- gen_loader = DataLoader(gen_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True)
190
 
191
  real_features = _extract_features(model, real_loader, dev)
192
  gen_features = _extract_features(model, gen_loader, dev)
 
24
  import torch
25
  import torch.nn as nn
26
  from torch.utils.data import DataLoader, Dataset
 
27
  HAS_TORCH = True
28
  except ImportError:
29
  HAS_TORCH = False
30
+ Dataset = object # type: ignore[misc,assignment]
31
 
32
 
33
  def _load_inception_v3() -> Any:
 
42
  return model
43
 
44
 
45
+ class ImageFolderDataset(Dataset):
46
+ """Simple dataset that loads images from a directory."""
47
+
48
+ def __init__(self, directory: str | Path, image_size: int = 299):
49
+ self.directory = Path(directory)
50
+ exts = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
51
+ self.files = sorted(
52
+ f for f in self.directory.iterdir()
53
+ if f.suffix.lower() in exts and f.is_file()
54
+ )
55
+ self.image_size = image_size
56
+
57
+ def __len__(self) -> int:
58
+ return len(self.files)
59
+
60
+ def __getitem__(self, idx: int) -> Any:
61
+ import cv2
62
+ img = cv2.imread(str(self.files[idx]))
63
+ if img is None:
64
+ # Return zeros if image can't be loaded
65
+ return torch.zeros(3, self.image_size, self.image_size)
66
+ img = cv2.resize(img, (self.image_size, self.image_size))
67
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
68
+ # Normalize to [0, 1] then ImageNet normalize
69
+ t = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1)
70
+ t = _imagenet_normalize(t)
71
+ return t
72
+
73
+
74
+ class NumpyArrayDataset(Dataset):
75
+ """Dataset wrapping a list of numpy arrays."""
76
+
77
+ def __init__(self, images: list[np.ndarray], image_size: int = 299):
78
+ self.images = images
79
+ self.image_size = image_size
80
+
81
+ def __len__(self) -> int:
82
+ return len(self.images)
83
+
84
+ def __getitem__(self, idx: int) -> Any:
85
+ import cv2
86
+ img = self.images[idx]
87
+ if img.shape[:2] != (self.image_size, self.image_size):
88
  img = cv2.resize(img, (self.image_size, self.image_size))
89
+ if img.ndim == 2:
90
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
91
+ elif img.shape[2] == 4:
92
+ img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
93
+ elif img.shape[2] == 3:
94
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
95
+ t = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1)
96
+ t = _imagenet_normalize(t)
97
+ return t
98
+
99
+
100
+ def _imagenet_normalize(t: torch.Tensor) -> torch.Tensor:
101
+ """Apply ImageNet normalization."""
102
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
103
+ std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
104
+ return (t - mean) / std
105
+
106
+
107
+ def _extract_features(
108
+ model: nn.Module,
109
+ dataloader: DataLoader,
110
+ device: torch.device,
111
+ ) -> np.ndarray:
112
+ """Extract InceptionV3 pool3 features from a dataloader."""
113
+ features = []
114
+ with torch.no_grad():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  for batch in dataloader:
116
  batch = batch.to(device)
117
  feat = model(batch)
118
  if isinstance(feat, tuple):
119
  feat = feat[0]
120
  features.append(feat.cpu().numpy())
121
+ return np.concatenate(features, axis=0)
122
 
123
 
124
  def _compute_statistics(features: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
125
  """Compute mean and covariance of feature vectors."""
126
+ if features.shape[0] < 2:
127
+ raise ValueError(
128
+ f"FID requires at least 2 images, got {features.shape[0]}"
129
+ )
130
  mu = np.mean(features, axis=0)
131
  sigma = np.cov(features, rowvar=False)
132
  return mu, sigma
133
 
134
 
135
  def _calculate_fid(
136
+ mu1: np.ndarray, sigma1: np.ndarray,
137
+ mu2: np.ndarray, sigma2: np.ndarray,
 
 
138
  ) -> float:
139
  """Calculate FID given two sets of statistics.
140
 
 
150
  covmean = covmean.real
151
 
152
  fid = diff @ diff + np.trace(sigma1 + sigma2 - 2 * covmean)
153
+ return float(max(fid, 0.0))
154
 
155
 
156
  def compute_fid_from_dirs(
 
187
  if len(real_ds) == 0 or len(gen_ds) == 0:
188
  raise ValueError("Need at least 1 image in each directory")
189
 
190
+ real_loader = DataLoader(real_ds, batch_size=batch_size,
191
+ num_workers=num_workers, pin_memory=True)
192
+ gen_loader = DataLoader(gen_ds, batch_size=batch_size,
193
+ num_workers=num_workers, pin_memory=True)
194
 
195
  real_features = _extract_features(model, real_loader, dev)
196
  gen_features = _extract_features(model, gen_loader, dev)