dreamlessx committed on
Commit
46ecbf8
·
verified ·
1 Parent(s): 1bb473c

Update landmarkdiff/synthetic/augmentation.py to v0.3.2

Browse files
landmarkdiff/synthetic/augmentation.py CHANGED
@@ -1,8 +1,9 @@
1
- """Clinical degradation augmentations.
2
 
3
- Degrades clean FFHQ/CelebA-HQ to match real clinical photo distribution.
4
- Applied from day 1 - domain gap prevention, not afterthought.
5
- 3-5 random augmentations per sample.
 
6
  """
7
 
8
  from __future__ import annotations
@@ -26,6 +27,8 @@ class AugmentationConfig:
26
  def point_source_lighting(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
27
  """Simulate point-source clinical lighting from a random direction."""
28
  h, w = image.shape[:2]
 
 
29
 
30
  # Random light source position
31
  lx = rng.uniform(0, w)
@@ -35,7 +38,7 @@ def point_source_lighting(image: np.ndarray, rng: np.random.Generator) -> np.nda
35
  # Distance-based falloff
36
  y_grid, x_grid = np.mgrid[0:h, 0:w].astype(np.float32)
37
  dist = np.sqrt((x_grid - lx) ** 2 + (y_grid - ly) ** 2)
38
- max_dist = np.sqrt(w**2 + h**2)
39
  light_map = 1.0 - (dist / max_dist) * intensity
40
 
41
  light_map = np.clip(light_map, 0.3, 1.0)
@@ -76,7 +79,8 @@ def jpeg_compression(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
76
  quality = int(rng.uniform(40, 85))
77
  encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
78
  _, encoded = cv2.imencode(".jpg", image, encode_param)
79
- return cv2.imdecode(encoded, cv2.IMREAD_COLOR)
 
80
 
81
 
82
  def gaussian_sensor_noise(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
@@ -89,6 +93,8 @@ def gaussian_sensor_noise(image: np.ndarray, rng: np.random.Generator) -> np.nda
89
  def barrel_distortion(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
90
  """Apply barrel/pincushion distortion simulating phone camera lens."""
91
  h, w = image.shape[:2]
 
 
92
  k1 = rng.uniform(-0.2, 0.2)
93
 
94
  fx = fy = max(w, h)
@@ -105,6 +111,9 @@ def barrel_distortion(image: np.ndarray, rng: np.random.Generator) -> np.ndarray
105
 
106
  def motion_blur(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
107
  """Slight motion blur (common in handheld clinical photos)."""
 
 
 
108
  size = int(rng.uniform(3, 7))
109
  angle = rng.uniform(0, 180)
110
 
@@ -117,7 +126,6 @@ def motion_blur(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
117
  if ksum > 0:
118
  kernel = kernel / ksum
119
  else:
120
- # rotation can zero out the kernel - fall back to identity
121
  kernel = np.zeros_like(kernel)
122
  kernel[size // 2, size // 2] = 1.0
123
 
@@ -127,12 +135,14 @@ def motion_blur(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
127
  def vignette(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
128
  """Add lens vignetting (darkened corners)."""
129
  h, w = image.shape[:2]
 
 
130
  strength = rng.uniform(0.3, 0.7)
131
 
132
  y, x = np.mgrid[0:h, 0:w].astype(np.float32)
133
  cx, cy = w / 2, h / 2
134
  dist = np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
135
- max_dist = np.sqrt(cx**2 + cy**2)
136
 
137
  mask = 1 - strength * (dist / max_dist) ** 2
138
  mask = np.clip(mask, 0.3, 1.0)
@@ -160,7 +170,20 @@ def apply_clinical_augmentation(
160
  max_augmentations: int = 5,
161
  rng: np.random.Generator | None = None,
162
  ) -> np.ndarray:
163
- """Apply random clinical degradation augmentations to an image."""
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  rng = rng or np.random.default_rng()
165
 
166
  # Select augmentations by probability
 
1
+ """Clinical degradation augmentation pipeline.
2
 
3
+ Degrades clean FFHQ/CelebA-HQ images to match real clinical photo distribution.
4
+ Applied from day 1 of training — domain gap prevention, not afterthought.
5
+
6
+ Each sample gets 3-5 random augmentations from the pool.
7
  """
8
 
9
  from __future__ import annotations
 
27
  def point_source_lighting(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
28
  """Simulate point-source clinical lighting from a random direction."""
29
  h, w = image.shape[:2]
30
+ if h < 4 or w < 4:
31
+ return image
32
 
33
  # Random light source position
34
  lx = rng.uniform(0, w)
 
38
  # Distance-based falloff
39
  y_grid, x_grid = np.mgrid[0:h, 0:w].astype(np.float32)
40
  dist = np.sqrt((x_grid - lx) ** 2 + (y_grid - ly) ** 2)
41
+ max_dist = np.sqrt(w ** 2 + h ** 2)
42
  light_map = 1.0 - (dist / max_dist) * intensity
43
 
44
  light_map = np.clip(light_map, 0.3, 1.0)
 
79
  quality = int(rng.uniform(40, 85))
80
  encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
81
  _, encoded = cv2.imencode(".jpg", image, encode_param)
82
+ decoded = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
83
+ return decoded if decoded is not None else image
84
 
85
 
86
  def gaussian_sensor_noise(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
 
93
  def barrel_distortion(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
94
  """Apply barrel/pincushion distortion simulating phone camera lens."""
95
  h, w = image.shape[:2]
96
+ if h < 4 or w < 4:
97
+ return image
98
  k1 = rng.uniform(-0.2, 0.2)
99
 
100
  fx = fy = max(w, h)
 
111
 
112
  def motion_blur(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
113
  """Slight motion blur (common in handheld clinical photos)."""
114
+ h, w = image.shape[:2]
115
+ if h < 4 or w < 4:
116
+ return image
117
  size = int(rng.uniform(3, 7))
118
  angle = rng.uniform(0, 180)
119
 
 
126
  if ksum > 0:
127
  kernel = kernel / ksum
128
  else:
 
129
  kernel = np.zeros_like(kernel)
130
  kernel[size // 2, size // 2] = 1.0
131
 
 
135
  def vignette(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
136
  """Add lens vignetting (darkened corners)."""
137
  h, w = image.shape[:2]
138
+ if h < 4 or w < 4:
139
+ return image
140
  strength = rng.uniform(0.3, 0.7)
141
 
142
  y, x = np.mgrid[0:h, 0:w].astype(np.float32)
143
  cx, cy = w / 2, h / 2
144
  dist = np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
145
+ max_dist = np.sqrt(cx ** 2 + cy ** 2)
146
 
147
  mask = 1 - strength * (dist / max_dist) ** 2
148
  mask = np.clip(mask, 0.3, 1.0)
 
170
  max_augmentations: int = 5,
171
  rng: np.random.Generator | None = None,
172
  ) -> np.ndarray:
173
+ """Apply random clinical degradation augmentations to an image.
174
+
175
+ Each sample gets min_augmentations to max_augmentations from the pool,
176
+ selected by their individual probabilities.
177
+
178
+ Args:
179
+ image: BGR input image (clean FFHQ/CelebA-HQ).
180
+ min_augmentations: Minimum number of augmentations to apply.
181
+ max_augmentations: Maximum number of augmentations to apply.
182
+ rng: Random number generator.
183
+
184
+ Returns:
185
+ Degraded image matching clinical photo distribution.
186
+ """
187
  rng = rng or np.random.default_rng()
188
 
189
  # Select augmentations by probability