Spaces:
Sleeping
Sleeping
Update landmarkdiff/synthetic/augmentation.py to v0.3.2
Browse files
landmarkdiff/synthetic/augmentation.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
-
"""Clinical degradation
|
| 2 |
|
| 3 |
-
Degrades clean FFHQ/CelebA-HQ to match real clinical photo distribution.
|
| 4 |
-
Applied from day 1
|
| 5 |
-
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
from __future__ import annotations
|
|
@@ -26,6 +27,8 @@ class AugmentationConfig:
|
|
| 26 |
def point_source_lighting(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
|
| 27 |
"""Simulate point-source clinical lighting from a random direction."""
|
| 28 |
h, w = image.shape[:2]
|
|
|
|
|
|
|
| 29 |
|
| 30 |
# Random light source position
|
| 31 |
lx = rng.uniform(0, w)
|
|
@@ -35,7 +38,7 @@ def point_source_lighting(image: np.ndarray, rng: np.random.Generator) -> np.nda
|
|
| 35 |
# Distance-based falloff
|
| 36 |
y_grid, x_grid = np.mgrid[0:h, 0:w].astype(np.float32)
|
| 37 |
dist = np.sqrt((x_grid - lx) ** 2 + (y_grid - ly) ** 2)
|
| 38 |
-
max_dist = np.sqrt(w**2 + h**2)
|
| 39 |
light_map = 1.0 - (dist / max_dist) * intensity
|
| 40 |
|
| 41 |
light_map = np.clip(light_map, 0.3, 1.0)
|
|
@@ -76,7 +79,8 @@ def jpeg_compression(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
|
|
| 76 |
quality = int(rng.uniform(40, 85))
|
| 77 |
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
|
| 78 |
_, encoded = cv2.imencode(".jpg", image, encode_param)
|
| 79 |
-
|
|
|
|
| 80 |
|
| 81 |
|
| 82 |
def gaussian_sensor_noise(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
|
|
@@ -89,6 +93,8 @@ def gaussian_sensor_noise(image: np.ndarray, rng: np.random.Generator) -> np.nda
|
|
| 89 |
def barrel_distortion(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
|
| 90 |
"""Apply barrel/pincushion distortion simulating phone camera lens."""
|
| 91 |
h, w = image.shape[:2]
|
|
|
|
|
|
|
| 92 |
k1 = rng.uniform(-0.2, 0.2)
|
| 93 |
|
| 94 |
fx = fy = max(w, h)
|
|
@@ -105,6 +111,9 @@ def barrel_distortion(image: np.ndarray, rng: np.random.Generator) -> np.ndarray
|
|
| 105 |
|
| 106 |
def motion_blur(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
|
| 107 |
"""Slight motion blur (common in handheld clinical photos)."""
|
|
|
|
|
|
|
|
|
|
| 108 |
size = int(rng.uniform(3, 7))
|
| 109 |
angle = rng.uniform(0, 180)
|
| 110 |
|
|
@@ -117,7 +126,6 @@ def motion_blur(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
|
|
| 117 |
if ksum > 0:
|
| 118 |
kernel = kernel / ksum
|
| 119 |
else:
|
| 120 |
-
# rotation can zero out the kernel - fall back to identity
|
| 121 |
kernel = np.zeros_like(kernel)
|
| 122 |
kernel[size // 2, size // 2] = 1.0
|
| 123 |
|
|
@@ -127,12 +135,14 @@ def motion_blur(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
|
|
| 127 |
def vignette(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
|
| 128 |
"""Add lens vignetting (darkened corners)."""
|
| 129 |
h, w = image.shape[:2]
|
|
|
|
|
|
|
| 130 |
strength = rng.uniform(0.3, 0.7)
|
| 131 |
|
| 132 |
y, x = np.mgrid[0:h, 0:w].astype(np.float32)
|
| 133 |
cx, cy = w / 2, h / 2
|
| 134 |
dist = np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
|
| 135 |
-
max_dist = np.sqrt(cx**2 + cy**2)
|
| 136 |
|
| 137 |
mask = 1 - strength * (dist / max_dist) ** 2
|
| 138 |
mask = np.clip(mask, 0.3, 1.0)
|
|
@@ -160,7 +170,20 @@ def apply_clinical_augmentation(
|
|
| 160 |
max_augmentations: int = 5,
|
| 161 |
rng: np.random.Generator | None = None,
|
| 162 |
) -> np.ndarray:
|
| 163 |
-
"""Apply random clinical degradation augmentations to an image.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
rng = rng or np.random.default_rng()
|
| 165 |
|
| 166 |
# Select augmentations by probability
|
|
|
|
| 1 |
+
"""Clinical degradation augmentation pipeline.
|
| 2 |
|
| 3 |
+
Degrades clean FFHQ/CelebA-HQ images to match real clinical photo distribution.
|
| 4 |
+
Applied from day 1 of training — domain gap prevention, not afterthought.
|
| 5 |
+
|
| 6 |
+
Each sample gets 3-5 random augmentations from the pool.
|
| 7 |
"""
|
| 8 |
|
| 9 |
from __future__ import annotations
|
|
|
|
| 27 |
def point_source_lighting(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
|
| 28 |
"""Simulate point-source clinical lighting from a random direction."""
|
| 29 |
h, w = image.shape[:2]
|
| 30 |
+
if h < 4 or w < 4:
|
| 31 |
+
return image
|
| 32 |
|
| 33 |
# Random light source position
|
| 34 |
lx = rng.uniform(0, w)
|
|
|
|
| 38 |
# Distance-based falloff
|
| 39 |
y_grid, x_grid = np.mgrid[0:h, 0:w].astype(np.float32)
|
| 40 |
dist = np.sqrt((x_grid - lx) ** 2 + (y_grid - ly) ** 2)
|
| 41 |
+
max_dist = np.sqrt(w ** 2 + h ** 2)
|
| 42 |
light_map = 1.0 - (dist / max_dist) * intensity
|
| 43 |
|
| 44 |
light_map = np.clip(light_map, 0.3, 1.0)
|
|
|
|
| 79 |
quality = int(rng.uniform(40, 85))
|
| 80 |
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
|
| 81 |
_, encoded = cv2.imencode(".jpg", image, encode_param)
|
| 82 |
+
decoded = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
|
| 83 |
+
return decoded if decoded is not None else image
|
| 84 |
|
| 85 |
|
| 86 |
def gaussian_sensor_noise(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
|
|
|
|
| 93 |
def barrel_distortion(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
|
| 94 |
"""Apply barrel/pincushion distortion simulating phone camera lens."""
|
| 95 |
h, w = image.shape[:2]
|
| 96 |
+
if h < 4 or w < 4:
|
| 97 |
+
return image
|
| 98 |
k1 = rng.uniform(-0.2, 0.2)
|
| 99 |
|
| 100 |
fx = fy = max(w, h)
|
|
|
|
| 111 |
|
| 112 |
def motion_blur(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
|
| 113 |
"""Slight motion blur (common in handheld clinical photos)."""
|
| 114 |
+
h, w = image.shape[:2]
|
| 115 |
+
if h < 4 or w < 4:
|
| 116 |
+
return image
|
| 117 |
size = int(rng.uniform(3, 7))
|
| 118 |
angle = rng.uniform(0, 180)
|
| 119 |
|
|
|
|
| 126 |
if ksum > 0:
|
| 127 |
kernel = kernel / ksum
|
| 128 |
else:
|
|
|
|
| 129 |
kernel = np.zeros_like(kernel)
|
| 130 |
kernel[size // 2, size // 2] = 1.0
|
| 131 |
|
|
|
|
| 135 |
def vignette(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
|
| 136 |
"""Add lens vignetting (darkened corners)."""
|
| 137 |
h, w = image.shape[:2]
|
| 138 |
+
if h < 4 or w < 4:
|
| 139 |
+
return image
|
| 140 |
strength = rng.uniform(0.3, 0.7)
|
| 141 |
|
| 142 |
y, x = np.mgrid[0:h, 0:w].astype(np.float32)
|
| 143 |
cx, cy = w / 2, h / 2
|
| 144 |
dist = np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
|
| 145 |
+
max_dist = np.sqrt(cx ** 2 + cy ** 2)
|
| 146 |
|
| 147 |
mask = 1 - strength * (dist / max_dist) ** 2
|
| 148 |
mask = np.clip(mask, 0.3, 1.0)
|
|
|
|
| 170 |
max_augmentations: int = 5,
|
| 171 |
rng: np.random.Generator | None = None,
|
| 172 |
) -> np.ndarray:
|
| 173 |
+
"""Apply random clinical degradation augmentations to an image.
|
| 174 |
+
|
| 175 |
+
Each sample gets min_augmentations to max_augmentations from the pool,
|
| 176 |
+
selected by their individual probabilities.
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
image: BGR input image (clean FFHQ/CelebA-HQ).
|
| 180 |
+
min_augmentations: Minimum number of augmentations to apply.
|
| 181 |
+
max_augmentations: Maximum number of augmentations to apply.
|
| 182 |
+
rng: Random number generator.
|
| 183 |
+
|
| 184 |
+
Returns:
|
| 185 |
+
Degraded image matching clinical photo distribution.
|
| 186 |
+
"""
|
| 187 |
rng = rng or np.random.default_rng()
|
| 188 |
|
| 189 |
# Select augmentations by probability
|