| from __future__ import annotations |
|
|
| import random |
| import cv2 |
| import numpy as np |
|
|
def random_color_distort(
    img: np.ndarray,
    brightness_delta: int = 32,
    contrast_low: float = 0.5,
    contrast_high: float = 1.5,
    saturation_low: float = 0.5,
    saturation_high: float = 1.5,
    hue_delta: int = 18,
) -> np.ndarray:
    """Apply SSD-style random photometric jitter to an RGB image.

    Brightness is considered first. Then contrast is considered either before
    or after the saturation/hue pair (50/50). Each individual adjustment fires
    independently with probability 0.5.

    Args:
        img: HWC **RGB uint8** image. It is not modified.
        brightness_delta: additive offset drawn from
            ``[-brightness_delta, brightness_delta]``.
        contrast_low, contrast_high: range of the multiplicative contrast factor.
        saturation_low, saturation_high: range of the factor applied to the
            HSV saturation channel.
        hue_delta: integer hue shift drawn from ``[-hue_delta, hue_delta]``,
            wrapped modulo 180 (OpenCV's 8-bit hue range).

    Returns:
        A new HWC RGB uint8 image.
    """

    def affine_u8(channel, scale=1.0, shift=0.0):
        # Compute in float32 so the affine map cannot overflow, then saturate to uint8.
        scaled = channel.astype(np.float32) * scale + shift
        return np.clip(scaled, 0, 255).astype(np.uint8)

    def jitter_contrast(bgr):
        return affine_u8(bgr, scale=random.uniform(contrast_low, contrast_high))

    def jitter_saturation(bgr):
        hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
        factor = random.uniform(saturation_low, saturation_high)
        hsv[:, :, 1] = affine_u8(hsv[:, :, 1], scale=factor)
        return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)

    def jitter_hue(bgr):
        hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
        shift = random.randint(-hue_delta, hue_delta)
        # Widen to int before adding so negative shifts wrap instead of underflowing.
        hue = hsv[:, :, 0].astype(int) + shift
        hsv[:, :, 0] = (hue % 180).astype(np.uint8)
        return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)

    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    if random.random() < 0.5:
        bgr = affine_u8(bgr, shift=random.uniform(-brightness_delta, brightness_delta))

    if random.random() < 0.5:
        stages = (jitter_contrast, jitter_saturation, jitter_hue)
    else:
        stages = (jitter_saturation, jitter_hue, jitter_contrast)

    for stage in stages:
        # One coin flip per stage; the stage draws its own parameter only if it fires.
        if random.random() < 0.5:
            bgr = stage(bgr)

    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
|
def random_flip(image: np.ndarray, label: np.ndarray):
    """Randomly mirror an image/label pair.

    A horizontal flip (axis 1) and then a vertical flip (axis 0) are each
    applied independently with probability 0.5, identically to both arrays.
    Any flipped result is returned as a C-contiguous copy. If neither flip
    fires, the inputs are returned as-is.
    """
    for axis in (1, 0):
        if random.random() < 0.5:
            image = np.ascontiguousarray(np.flip(image, axis=axis))
            label = np.ascontiguousarray(np.flip(label, axis=axis))
    return image, label
|
|
def random_rotate90(image: np.ndarray, label: np.ndarray):
    """Rotate image and label together by a random multiple of 90°.

    The number of quarter turns is drawn uniformly from {0, 1, 2, 3}. The
    rotation is in the plane of the first two axes. With zero turns the
    inputs are returned untouched; otherwise both results are fresh copies.
    """
    turns = random.randint(0, 3)
    if not turns:
        return image, label

    def _turn(arr):
        # np.rot90 returns a view; copy so the result owns contiguous memory.
        return np.rot90(arr, turns, axes=(0, 1)).copy()

    return _turn(image), _turn(label)
|
|
def random_crop(image: np.ndarray, label: np.ndarray, crop_size: int):
    """Extract the same random ``crop_size × crop_size`` window from image and label.

    The top-left corner is drawn uniformly (row offset first, then column
    offset) so that the window lies entirely inside the first two axes of
    ``image``. The returned arrays are views into the inputs, not copies.

    Raises:
        ValueError: if ``crop_size`` is not positive or is larger than either
            spatial dimension (pad the inputs first in that case).
    """
    h, w = image.shape[:2]
    # Without this check an oversized crop fails inside random.randint with an
    # opaque "empty range" message, and a non-positive size yields garbage slices.
    if crop_size <= 0 or crop_size > h or crop_size > w:
        raise ValueError(
            f"crop_size={crop_size} does not fit image of size {h}x{w}; "
            "pad the inputs first (e.g. with pad_to_size)"
        )
    top = random.randint(0, h - crop_size)
    left = random.randint(0, w - crop_size)
    window = (slice(top, top + crop_size), slice(left, left + crop_size))
    return image[window], label[window]
|
|
def center_crop(image: np.ndarray, label: np.ndarray, crop_size: int):
    """Centre crop for validation.

    Extracts the ``crop_size × crop_size`` window centred on the first two
    axes of ``image`` and the same window from ``label``. When the spare
    margin is odd, the extra pixel is left on the bottom/right. The returned
    arrays are views into the inputs.

    Raises:
        ValueError: if ``crop_size`` is not positive or is larger than either
            spatial dimension. Without this check the offsets would go
            negative, the slices would wrap from the end, and a silently
            wrong-sized crop would be returned.
    """
    h, w = image.shape[:2]
    if crop_size <= 0 or crop_size > h or crop_size > w:
        raise ValueError(
            f"crop_size={crop_size} does not fit image of size {h}x{w}; "
            "pad the inputs first (e.g. with pad_to_size)"
        )
    top = (h - crop_size) // 2
    left = (w - crop_size) // 2
    window = (slice(top, top + crop_size), slice(left, left + crop_size))
    return image[window], label[window]
|
|
def pad_to_size(
    image: np.ndarray,
    label: np.ndarray,
    min_size: int,
    pad_label_value: int = 0,
) -> tuple[np.ndarray, np.ndarray]:
    """Pad image and label so both spatial sides are ≥ ``min_size``.

    Padding is split evenly between the two ends of each of the first two
    axes; when the amount is odd, the extra pixel goes to the bottom/right.
    The image is padded by mirroring its content (``mode="symmetric"``). The
    label is padded with the constant ``pad_label_value``. Any trailing axes
    (e.g. channels) are left unpadded, so both 2-D (grayscale) and 3-D
    images, and labels of any rank ≥ 2, are supported.

    If no padding is needed, the inputs are returned unchanged (same objects).

    Args:
        image: array whose first two axes are height and width.
        label: array aligned with ``image`` on its first two axes.
        min_size: required minimum height and width.
        pad_label_value: fill value used for the padded label area.

    Returns:
        ``(image, label)`` with both spatial sides ≥ ``min_size``.
    """
    h, w = image.shape[:2]
    if h >= min_size and w >= min_size:
        return image, label

    pad_h = max(min_size - h, 0)
    pad_w = max(min_size - w, 0)
    spatial = (
        (pad_h // 2, pad_h - pad_h // 2),
        (pad_w // 2, pad_w - pad_w // 2),
    )

    def _widths(arr):
        # Pad only the two leading (spatial) axes; leave any channel axes alone.
        return spatial + ((0, 0),) * (arr.ndim - 2)

    image = np.pad(image, _widths(image), mode="symmetric")
    label = np.pad(
        label, _widths(label), mode="constant", constant_values=pad_label_value
    )
    return image, label
|
|
def get_training_augmentation(
    image: np.ndarray,
    label: np.ndarray,
    crop_size: int = 400,
    color_distort: bool = True,
) -> tuple[np.ndarray, np.ndarray]:
    """Full training augmentation pipeline.

    Steps:
        1. Pad if smaller than ``crop_size`` (image mirrored, label filled with 0)
        2. Random horizontal/vertical flip
        3. Random 90° rotation
        4. Random ``crop_size × crop_size`` crop
        5. Optional colour distortion (image only)

    Args:
        image: HWC RGB uint8 image.
        label: label array aligned with ``image`` on its first two axes.
        crop_size: side length of the square output window.
        color_distort: whether to apply random photometric jitter.

    Returns:
        ``(image, label)``, both ``crop_size`` long on their first two axes.
    """
    # NOTE(review): validation pads labels with 255 while training pads with 0.
    # If 0 is a real class, the mirrored padding here is supervised as class 0 —
    # confirm this is intended.
    image, label = pad_to_size(image, label, crop_size, pad_label_value=0)
    image, label = random_flip(image, label)
    image, label = random_rotate90(image, label)
    image, label = random_crop(image, label, crop_size)
    # Colour jitter is a per-pixel map with globally drawn parameters, so it
    # commutes with padding, flipping, rotating and cropping. Running it last
    # means it processes only the crop instead of the full (possibly much
    # larger) image. The output distribution is unchanged; only the order in
    # which random numbers are consumed differs.
    if color_distort:
        image = random_color_distort(image)
    return image, label
|
|
def get_validation_transform(
    image: np.ndarray,
    label: np.ndarray,
    crop_size: int = 480,
) -> tuple[np.ndarray, np.ndarray]:
    """Deterministic validation transform: pad, then centre crop.

    The pair is first padded so each spatial side is at least ``crop_size``.
    The image is mirrored and the label is filled with 255 (presumably the
    loss's ignore index — confirm against the training setup). The result is
    then centre-cropped to ``crop_size × crop_size``. No randomness is used.
    """
    padded_image, padded_label = pad_to_size(
        image, label, crop_size, pad_label_value=255
    )
    return center_crop(padded_image, padded_label, crop_size)
|
|