# LandmarkDiff/landmarkdiff/synthetic/augmentation.py
# Last change (dreamlessx, commit 46ecbf8, verified):
#   Update landmarkdiff/synthetic/augmentation.py to v0.3.2
"""Clinical degradation augmentation pipeline.
Degrades clean FFHQ/CelebA-HQ images to match real clinical photo distribution.
Applied from day 1 of training — domain gap prevention, not afterthought.
Each sample gets 3-5 random augmentations from the pool.
"""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
import cv2
import numpy as np
@dataclass(frozen=True)
class AugmentationConfig:
    """One entry of the augmentation pool: a named degradation and its selection probability.

    Instances are immutable (``frozen=True``).
    """

    # Short identifier for the augmentation.
    name: str
    # Degradation callable: receives (image, rng) and returns the degraded image.
    fn: Callable[[np.ndarray, np.random.Generator], np.ndarray]
    # Selection chance; apply_clinical_augmentation compares it against rng.random().
    probability: float
def point_source_lighting(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Darken the image with distance from a randomly placed light source.

    A light position is drawn uniformly inside the frame together with a
    falloff intensity in [0.3, 0.7). Every pixel is scaled by a gain that
    decreases linearly with its distance to the light (normalised by the
    image diagonal); the gain is clamped to [0.3, 1.0].

    Args:
        image: Input image; the gain map is stacked to three channels, so a
            (H, W, 3) array is expected.
        rng: Random generator; three uniform draws are consumed.

    Returns:
        A new uint8 image, or ``image`` itself (unmodified, no draws
        consumed) when either side is shorter than 4 pixels.
    """
    height, width = image.shape[:2]
    if min(height, width) < 4:
        return image

    # Draw order (x, y, intensity) is fixed so seeded runs stay reproducible.
    light_x = rng.uniform(0, width)
    light_y = rng.uniform(0, height)
    falloff = rng.uniform(0.3, 0.7)

    # Per-pixel distance to the light, normalised by the image diagonal.
    rows, cols = np.indices((height, width), dtype=np.float32)
    distance = np.sqrt((cols - light_x) ** 2 + (rows - light_y) ** 2)
    diagonal = np.sqrt(width ** 2 + height ** 2)

    gain = np.clip(1.0 - (distance / diagonal) * falloff, 0.3, 1.0)
    gain_3ch = np.stack((gain, gain, gain), axis=-1)

    lit = image.astype(np.float32) * gain_3ch
    return np.clip(lit, 0, 255).astype(np.uint8)
def color_temperature_jitter(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Randomly warm or cool the image by rebalancing the red and blue channels.

    A shift is drawn uniformly from [-0.15, 0.15). A positive shift warms the
    image (red boosted, blue reduced); otherwise the image is cooled (blue
    boosted, red reduced). The boosted channel is scaled by ``1 + |shift|``
    and the opposite one by ``1 - |shift| / 2``; green is left untouched.

    Args:
        image: BGR image (channel 0 = blue, channel 2 = red).
        rng: Random generator; one uniform draw is consumed.

    Returns:
        A new uint8 image with values clipped to [0, 255].
    """
    shift = rng.uniform(-0.15, 0.15)
    magnitude = abs(shift)

    # BGR layout: warming boosts red (2) and cuts blue (0); cooling is the reverse.
    boosted, reduced = (2, 0) if shift > 0 else (0, 2)

    pixels = image.astype(np.float32)
    pixels[:, :, boosted] *= 1 + magnitude
    pixels[:, :, reduced] *= 1 - magnitude * 0.5
    return np.clip(pixels, 0, 255).astype(np.uint8)
def green_fluorescent_cast(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Tint the image green, imitating fluorescent lighting in clinical settings.

    An intensity is drawn uniformly from [0.05, 0.15). Green is scaled by
    ``1 + intensity`` while blue and red are each scaled by
    ``1 - 0.3 * intensity``.

    Args:
        image: Image with at least three channels (index 1 = green).
        rng: Random generator; one uniform draw is consumed.

    Returns:
        A new uint8 image with values clipped to [0, 255].
    """
    intensity = rng.uniform(0.05, 0.15)
    green_gain = 1 + intensity
    side_gain = 1 - intensity * 0.3  # mild attenuation of blue and red

    tinted = image.astype(np.float32)
    for channel, gain in enumerate((side_gain, green_gain, side_gain)):
        tinted[:, :, channel] *= gain
    return np.clip(tinted, 0, 255).astype(np.uint8)
def jpeg_compression(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Round-trip the image through JPEG to introduce compression artifacts.

    The quality is ``int(rng.uniform(40, 85))``, i.e. an integer in 40..84
    (the float draw is truncated, so 85 itself is never produced).

    Args:
        image: Image accepted by ``cv2.imencode`` for the ``.jpg`` format.
        rng: Random generator; one uniform draw is consumed.

    Returns:
        The image decoded with ``cv2.IMREAD_COLOR``. If either the encode or
        the decode step fails, the input ``image`` is returned unchanged so a
        codec problem never aborts the augmentation pipeline.
    """
    quality = int(rng.uniform(40, 85))
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]

    # imencode reports failure through its first return value; honour it
    # instead of handing an invalid buffer to imdecode.
    ok, encoded = cv2.imencode(".jpg", image, encode_param)
    if not ok:
        return image

    decoded = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    return image if decoded is None else decoded
def gaussian_sensor_noise(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Corrupt the image with zero-mean Gaussian noise.

    The standard deviation is drawn uniformly from [5, 25) and an independent
    noise sample is added to every element of ``image``.

    Args:
        image: Input image of any shape.
        rng: Random generator; one uniform draw followed by one normal draw
            of ``image.shape`` is consumed.

    Returns:
        A new uint8 image with values clipped to [0, 255].
    """
    sigma = rng.uniform(5, 25)
    noise = rng.normal(loc=0, scale=sigma, size=image.shape)
    noisy = image.astype(np.float32) + noise.astype(np.float32)
    return np.clip(noisy, 0, 255).astype(np.uint8)
def barrel_distortion(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Warp the image with a random radial lens distortion (barrel/pincushion).

    A single radial coefficient ``k1`` is drawn uniformly from [-0.2, 0.2).
    The camera model uses a focal length of ``max(w, h)`` and a principal
    point at the image centre; pixels sampled from outside the frame are
    filled by reflection.

    Args:
        image: Input image.
        rng: Random generator; one uniform draw is consumed.

    Returns:
        The remapped image, or ``image`` itself (unmodified, no draws
        consumed) when either side is shorter than 4 pixels.
    """
    height, width = image.shape[:2]
    if min(height, width) < 4:
        return image

    radial_k1 = rng.uniform(-0.2, 0.2)

    focal = max(width, height)
    intrinsics = np.array(
        [
            [focal, 0, width / 2],
            [0, focal, height / 2],
            [0, 0, 1],
        ],
        dtype=np.float64,
    )
    distortion = np.array([radial_k1, 0, 0, 0, 0], dtype=np.float64)

    # Same intrinsics on both sides: only the radial term moves pixels.
    map_x, map_y = cv2.initUndistortRectifyMap(
        intrinsics, distortion, None, intrinsics, (width, height), cv2.CV_32FC1
    )
    return cv2.remap(
        image, map_x, map_y, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101
    )
def motion_blur(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Apply a short linear motion blur at a random angle.

    A horizontal line kernel of side ``int(rng.uniform(3, 7))`` (3..6) is
    rotated by an angle drawn from [0, 180) degrees, renormalised to sum to
    one, and convolved with the image.

    The kernel is rotated about pixel ``(size // 2, size // 2)``: the line is
    drawn through that pixel and ``cv2.filter2D`` uses the same pixel as its
    default anchor. Rotating about ``(size / 2, size / 2)`` instead is half a
    pixel off for odd sizes, which pushes part of the line out of the kernel
    and yields a lopsided kernel that shifts the image as well as blurring it.

    Args:
        image: Input image.
        rng: Random generator; two uniform draws are consumed (size, angle).

    Returns:
        The blurred image, or ``image`` itself (unmodified, no draws
        consumed) when either side is shorter than 4 pixels.
    """
    h, w = image.shape[:2]
    if h < 4 or w < 4:
        return image

    size = int(rng.uniform(3, 7))
    angle = rng.uniform(0, 180)

    anchor = size // 2
    kernel = np.zeros((size, size))
    kernel[anchor, :] = 1.0 / size

    # Rotate about the pixel the line passes through (also filter2D's anchor).
    rotation = cv2.getRotationMatrix2D((float(anchor), float(anchor)), angle, 1)
    kernel = cv2.warpAffine(kernel, rotation, (size, size))

    # Interpolation at the kernel border can drop mass; renormalise so the
    # blur preserves brightness, and fall back to identity if nothing is left.
    ksum = kernel.sum()
    if ksum > 0:
        kernel = kernel / ksum
    else:
        kernel = np.zeros_like(kernel)
        kernel[anchor, anchor] = 1.0
    return cv2.filter2D(image, -1, kernel)
def vignette(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Darken the image towards its corners to imitate lens vignetting.

    A strength is drawn uniformly from [0.3, 0.7). Each pixel is scaled by
    ``1 - strength * r**2`` where ``r`` is its distance from ``(w/2, h/2)``
    divided by the centre-to-corner distance; the gain is clamped to
    [0.3, 1.0].

    Args:
        image: Input image; the gain map is stacked to three channels, so a
            (H, W, 3) array is expected.
        rng: Random generator; one uniform draw is consumed.

    Returns:
        A new uint8 image, or ``image`` itself (unmodified, no draws
        consumed) when either side is shorter than 4 pixels.
    """
    height, width = image.shape[:2]
    if min(height, width) < 4:
        return image

    strength = rng.uniform(0.3, 0.7)

    center_x, center_y = width / 2, height / 2
    rows, cols = np.indices((height, width), dtype=np.float32)
    radius = np.sqrt((cols - center_x) ** 2 + (rows - center_y) ** 2)
    corner_radius = np.sqrt(center_x ** 2 + center_y ** 2)

    # Quadratic falloff with the normalised radius.
    gain = np.clip(1 - strength * (radius / corner_radius) ** 2, 0.3, 1.0)
    gain_3ch = np.stack((gain, gain, gain), axis=-1)

    shaded = image.astype(np.float32) * gain_3ch
    return np.clip(shaded, 0, 255).astype(np.uint8)
# Pool of available degradations; each entry carries the probability (from
# the spec) with which it is picked for a given sample.
AUGMENTATION_POOL: list[AugmentationConfig] = [
    AugmentationConfig(
        name="point_source_lighting", fn=point_source_lighting, probability=0.40
    ),
    AugmentationConfig(
        name="color_temperature", fn=color_temperature_jitter, probability=0.60
    ),
    AugmentationConfig(
        name="green_fluorescent", fn=green_fluorescent_cast, probability=0.25
    ),
    AugmentationConfig(
        name="jpeg_compression", fn=jpeg_compression, probability=0.30
    ),
    AugmentationConfig(
        name="sensor_noise", fn=gaussian_sensor_noise, probability=0.40
    ),
    AugmentationConfig(
        name="barrel_distortion", fn=barrel_distortion, probability=0.30
    ),
    AugmentationConfig(name="motion_blur", fn=motion_blur, probability=0.20),
    AugmentationConfig(name="vignette", fn=vignette, probability=0.25),
]
def apply_clinical_augmentation(
    image: np.ndarray,
    min_augmentations: int = 3,
    max_augmentations: int = 5,
    rng: np.random.Generator | None = None,
) -> np.ndarray:
    """Degrade a clean image with a random subset of the augmentation pool.

    Every pool entry is first picked independently with its own probability.
    If fewer than ``min_augmentations`` were picked, randomly chosen unpicked
    entries top the selection up; if more than ``max_augmentations`` were
    picked, a random subset of that size is kept. The survivors are then
    applied one after another in random order.

    Args:
        image: BGR input image (clean FFHQ/CelebA-HQ).
        min_augmentations: Minimum number of augmentations to apply.
        max_augmentations: Maximum number of augmentations to apply.
        rng: Random number generator; a fresh default generator is created
            when omitted.

    Returns:
        Degraded image matching clinical photo distribution. The input array
        is copied before any augmentation runs.
    """
    rng = rng or np.random.default_rng()

    # One draw per pool entry, in pool order.
    chosen = [aug for aug in AUGMENTATION_POOL if rng.random() < aug.probability]

    # Too few: top up from the entries that were not picked.
    shortfall = min_augmentations - len(chosen)
    if shortfall > 0:
        leftovers = [aug for aug in AUGMENTATION_POOL if aug not in chosen]
        rng.shuffle(leftovers)
        chosen += leftovers[:shortfall]

    # Too many: keep a random subset.
    if len(chosen) > max_augmentations:
        rng.shuffle(chosen)
        chosen = chosen[:max_augmentations]

    # Randomise the application order.
    rng.shuffle(chosen)

    degraded = image.copy()
    for augmentation in chosen:
        degraded = augmentation.fn(degraded, rng)
    return degraded