TTP / opencd /datasets /transforms /transforms.py
KyanChen's picture
Upload 1861 files
3b96cb1
# Copyright (c) Open-CD. All rights reserved.
import copy
import warnings
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
import cv2
import mmcv
import numpy as np
from mmcv.image.geometric import _scale_size
from mmcv.transforms.base import BaseTransform
from mmcv.transforms.utils import cache_randomness
from mmengine.utils import is_list_of, is_seq_of, is_str, is_tuple_of
from numpy import random
from scipy.ndimage import gaussian_filter
from mmseg.datasets.dataset_wrappers import MultiImageMixDataset
from opencd.registry import TRANSFORMS
try:
import albumentations
from albumentations import Compose
except ImportError:
albumentations = None
Compose = None
@TRANSFORMS.register_module()
class MultiImgResizeToMultiple(BaseTransform):
"""Resize images & seg to multiple of divisor.
Required Keys:
- img
- gt_seg_map
Modified Keys:
- img
- img_shape
- pad_shape
Args:
size_divisor (int): images and gt seg maps need to resize to multiple
of size_divisor. Default: 32.
interpolation (str, optional): The interpolation mode of image resize.
Default: None
"""
def __init__(self, size_divisor=32, interpolation=None):
self.size_divisor = size_divisor
self.interpolation = interpolation
def transform(self, results: dict) -> dict:
"""Call function to resize images, semantic segmentation map to
multiple of size divisor.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Resized results, 'img_shape', 'pad_shape' keys are updated.
"""
# Align image to multiple of size divisor.
imgs = results['img']
imgs = [
mmcv.imresize_to_multiple(
img,
self.size_divisor,
scale_factor=1,
interpolation=self.interpolation
if self.interpolation else 'bilinear') for img in imgs]
results['img'] = imgs
results['img_shape'] = imgs[0].shape
results['pad_shape'] = imgs[0].shape
# Align segmentation map to multiple of size divisor.
for key in results.get('seg_fields', []):
gt_seg = results[key]
gt_seg = mmcv.imresize_to_multiple(
gt_seg,
self.size_divisor,
scale_factor=1,
interpolation='nearest')
results[key] = gt_seg
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += (f'(size_divisor={self.size_divisor}, '
f'interpolation={self.interpolation})')
return repr_str
@TRANSFORMS.register_module()
class MultiImgRerange(BaseTransform):
"""Rerange the image pixel value.
Required Keys:
- img
Modified Keys:
- img
Args:
min_value (float or int): Minimum value of the reranged image.
Default: 0.
max_value (float or int): Maximum value of the reranged image.
Default: 255.
"""
def __init__(self, min_value=0, max_value=255):
assert isinstance(min_value, float) or isinstance(min_value, int)
assert isinstance(max_value, float) or isinstance(max_value, int)
assert min_value < max_value
self.min_value = min_value
self.max_value = max_value
def transform(self, results: dict) -> dict:
"""Call function to rerange images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Reranged results.
"""
def _rerange(img):
img_min_value = np.min(img)
img_max_value = np.max(img)
assert img_min_value < img_max_value
# rerange to [0, 1]
img = (img - img_min_value) / (img_max_value - img_min_value)
# rerange to [min_value, max_value]
img = img * (self.max_value - self.min_value) + self.min_value
return img
results['img'] = [_rerange(img) for img in results['img']]
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(min_value={self.min_value}, max_value={self.max_value})'
return repr_str
@TRANSFORMS.register_module()
class MultiImgCLAHE(BaseTransform):
"""Use CLAHE method to process the image.
See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
Graphics Gems, 1994:474-485.` for more information.
Required Keys:
- img
Modified Keys:
- img
Args:
clip_limit (float): Threshold for contrast limiting. Default: 40.0.
tile_grid_size (tuple[int]): Size of grid for histogram equalization.
Input image will be divided into equally sized rectangular tiles.
It defines the number of tiles in row and column. Default: (8, 8).
"""
def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)):
assert isinstance(clip_limit, (float, int))
self.clip_limit = clip_limit
assert is_tuple_of(tile_grid_size, int)
assert len(tile_grid_size) == 2
self.tile_grid_size = tile_grid_size
def transform(self, results: dict) -> dict:
"""Call function to Use CLAHE method process images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Processed results.
"""
def _clane(img):
for i in range(img.shape[2]):
img[:, :, i] = mmcv.clahe(
np.array(img[:, :, i], dtype=np.uint8),
self.clip_limit, self.tile_grid_size)
return img
results['img'] = [_clane(img) for img in results['img']]
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(clip_limit={self.clip_limit}, '\
f'tile_grid_size={self.tile_grid_size})'
return repr_str
@TRANSFORMS.register_module()
class MultiImgRandomCrop(BaseTransform):
"""Random crop the image & seg.
Required Keys:
- img
- gt_seg_map
Modified Keys:
- img
- img_shape
- gt_seg_map
Args:
crop_size (Union[int, Tuple[int, int]]): Expected size after cropping
with the format of (h, w). If set to an integer, then cropping
width and height are equal to this integer.
cat_max_ratio (float): The maximum ratio that single category could
occupy.
ignore_index (int): The label index to be ignored. Default: 255
"""
def __init__(self,
crop_size: Union[int, Tuple[int, int]],
cat_max_ratio: float = 1.,
ignore_index: int = 255):
super().__init__()
assert isinstance(crop_size, int) or (
isinstance(crop_size, tuple) and len(crop_size) == 2
), 'The expected crop_size is an integer, or a tuple containing two '
'intergers'
if isinstance(crop_size, int):
crop_size = (crop_size, crop_size)
assert crop_size[0] > 0 and crop_size[1] > 0
self.crop_size = crop_size
self.cat_max_ratio = cat_max_ratio
self.ignore_index = ignore_index
@cache_randomness
def crop_bbox(self, results: dict) -> tuple:
"""get a crop bounding box.
Args:
results (dict): Result dict from loading pipeline.
Returns:
tuple: Coordinates of the cropped image.
"""
def generate_crop_bbox(img: np.ndarray) -> tuple:
"""Randomly get a crop bounding box.
Args:
img (np.ndarray): Original input image.
Returns:
tuple: Coordinates of the cropped image.
"""
margin_h = max(img.shape[0] - self.crop_size[0], 0)
margin_w = max(img.shape[1] - self.crop_size[1], 0)
offset_h = np.random.randint(0, margin_h + 1)
offset_w = np.random.randint(0, margin_w + 1)
crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]
crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]
return crop_y1, crop_y2, crop_x1, crop_x2
img = results['img'][0]
crop_bbox = generate_crop_bbox(img)
if self.cat_max_ratio < 1.:
# Repeat 10 times
for _ in range(10):
seg_temp = self.crop(results['gt_seg_map'], crop_bbox)
labels, cnt = np.unique(seg_temp, return_counts=True)
cnt = cnt[labels != self.ignore_index]
if len(cnt) > 1 and np.max(cnt) / np.sum(
cnt) < self.cat_max_ratio:
break
crop_bbox = generate_crop_bbox(img)
return crop_bbox
def crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray:
"""Crop from ``img``
Args:
img (np.ndarray): Original input image.
crop_bbox (tuple): Coordinates of the cropped image.
Returns:
np.ndarray: The cropped image.
"""
crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
return img
def transform(self, results: dict) -> dict:
"""Transform function to randomly crop images, semantic segmentation
maps.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Randomly cropped results, 'img_shape' key in result dict is
updated according to crop size.
"""
crop_bbox = self.crop_bbox(results)
# crop the image
imgs = [self.crop(img, crop_bbox) for img in results['img']]
# crop semantic seg
for key in results.get('seg_fields', []):
results[key] = self.crop(results[key], crop_bbox)
results['img'] = imgs
results['img_shape'] = imgs[0].shape
return results
def __repr__(self):
return self.__class__.__name__ + f'(crop_size={self.crop_size})'
@TRANSFORMS.register_module()
class MultiImgRandomRotate(BaseTransform):
"""Rotate the image & seg.
Required Keys:
- img
- gt_seg_map
Modified Keys:
- img
- gt_seg_map
Args:
prob (float): The rotation probability.
degree (float, tuple[float]): Range of degrees to select from. If
degree is a number instead of tuple like (min, max),
the range of degree will be (``-degree``, ``+degree``)
pad_val (float, optional): Padding value of image. Default: 0.
seg_pad_val (float, optional): Padding value of segmentation map.
Default: 255.
center (tuple[float], optional): Center point (w, h) of the rotation in
the source image. If not specified, the center of the image will be
used. Default: None.
auto_bound (bool): Whether to adjust the image size to cover the whole
rotated image. Default: False
"""
def __init__(self,
prob,
degree,
pad_val=0,
seg_pad_val=255,
center=None,
auto_bound=False):
self.prob = prob
assert prob >= 0 and prob <= 1
if isinstance(degree, (float, int)):
assert degree > 0, f'degree {degree} should be positive'
self.degree = (-degree, degree)
else:
self.degree = degree
assert len(self.degree) == 2, f'degree {self.degree} should be a ' \
f'tuple of (min, max)'
self.pal_val = pad_val
self.seg_pad_val = seg_pad_val
self.center = center
self.auto_bound = auto_bound
@cache_randomness
def generate_degree(self):
return np.random.rand() < self.prob, np.random.uniform(
min(*self.degree), max(*self.degree))
def transform(self, results: dict) -> dict:
"""Call function to rotate image, semantic segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Rotated results.
"""
rotate, degree = self.generate_degree()
if rotate:
# rotate image
results['img'] = [
mmcv.imrotate(
img,
angle=degree,
border_value=self.pal_val,
center=self.center,
auto_bound=self.auto_bound) for img in results['img']]
# rotate segs
for key in results.get('seg_fields', []):
results[key] = mmcv.imrotate(
results[key],
angle=degree,
border_value=self.seg_pad_val,
center=self.center,
auto_bound=self.auto_bound,
interpolation='nearest')
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob}, ' \
f'degree={self.degree}, ' \
f'pad_val={self.pal_val}, ' \
f'seg_pad_val={self.seg_pad_val}, ' \
f'center={self.center}, ' \
f'auto_bound={self.auto_bound})'
return repr_str
@TRANSFORMS.register_module()
class MultiImgRGB2Gray(BaseTransform):
"""Convert RGB image to grayscale image.
Required Keys:
- img
Modified Keys:
- img
- img_shape
This transform calculate the weighted mean of input image channels with
``weights`` and then expand the channels to ``out_channels``. When
``out_channels`` is None, the number of output channels is the same as
input channels.
Args:
out_channels (int): Expected number of output channels after
transforming. Default: None.
weights (tuple[float]): The weights to calculate the weighted mean.
Default: (0.299, 0.587, 0.114).
"""
def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)):
assert out_channels is None or out_channels > 0
self.out_channels = out_channels
assert isinstance(weights, tuple)
for item in weights:
assert isinstance(item, (float, int))
self.weights = weights
def transform(self, results: dict) -> dict:
"""Call function to convert RGB image to grayscale image.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Result dict with grayscale image.
"""
def _rgb2gray(img):
assert len(img.shape) == 3
assert img.shape[2] == len(self.weights)
weights = np.array(self.weights).reshape((1, 1, -1))
img = (img * weights).sum(2, keepdims=True)
if self.out_channels is None:
img = img.repeat(weights.shape[2], axis=2)
else:
img = img.repeat(self.out_channels, axis=2)
return img
imgs = [_rgb2gray(img) for img in results['img']]
results['img'] = imgs
results['img_shape'] = imgs[0].shape
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(out_channels={self.out_channels}, ' \
f'weights={self.weights})'
return repr_str
@TRANSFORMS.register_module()
class MultiImgAdjustGamma(BaseTransform):
"""Using gamma correction to process the image.
Required Keys:
- img
Modified Keys:
- img
Args:
gamma (float or int): Gamma value used in gamma correction.
Default: 1.0.
"""
def __init__(self, gamma=1.0):
assert isinstance(gamma, float) or isinstance(gamma, int)
assert gamma > 0
self.gamma = gamma
inv_gamma = 1.0 / gamma
self.table = np.array([(i / 255.0)**inv_gamma * 255
for i in np.arange(256)]).astype('uint8')
def transform(self, results: dict) -> dict:
"""Call function to process the image with gamma correction.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Processed results.
"""
results['img'] = [
mmcv.lut_transform(
np.array(img, dtype=np.uint8), self.table) for img in results['img']
]
return results
def __repr__(self):
return self.__class__.__name__ + f'(gamma={self.gamma})'
@TRANSFORMS.register_module()
class MultiImgPhotoMetricDistortion(BaseTransform):
"""Apply photometric distortion to image sequentially, every transformation
is applied with a probability of 0.5. The position of random contrast is in
second or second to last.
1. random brightness
2. random contrast (mode 0)
3. convert color from BGR to HSV
4. random saturation
5. random hue
6. convert color from HSV to BGR
7. random contrast (mode 1)
Required Keys:
- img
Modified Keys:
- img
Args:
brightness_delta (int): delta of brightness.
contrast_range (tuple): range of contrast.
saturation_range (tuple): range of saturation.
hue_delta (int): delta of hue.
consistent_contrast_mode (bool): Whether to
keep the contrast mode consistent.
"""
def __init__(self,
brightness_delta: int = 32,
contrast_range: Sequence[float] = (0.5, 1.5),
saturation_range: Sequence[float] = (0.5, 1.5),
hue_delta: int = 18,
consistent_contrast_mode: bool = False):
self.brightness_delta = brightness_delta
self.contrast_lower, self.contrast_upper = contrast_range
self.saturation_lower, self.saturation_upper = saturation_range
self.hue_delta = hue_delta
self.consistent_contrast_mode = consistent_contrast_mode
def convert(self,
img: np.ndarray,
alpha: int = 1,
beta: int = 0) -> np.ndarray:
"""Multiple with alpha and add beat with clip.
Args:
img (np.ndarray): The input image.
alpha (int): Image weights, change the contrast/saturation
of the image. Default: 1
beta (int): Image bias, change the brightness of the
image. Default: 0
Returns:
np.ndarray: The transformed image.
"""
img = img.astype(np.float32) * alpha + beta
img = np.clip(img, 0, 255)
return img.astype(np.uint8)
def brightness(self, img: np.ndarray) -> np.ndarray:
"""Brightness distortion.
Args:
img (np.ndarray): The input image.
Returns:
np.ndarray: Image after brightness change.
"""
if random.randint(2):
return self.convert(
img,
beta=random.uniform(-self.brightness_delta,
self.brightness_delta))
return img
def contrast(self, img: np.ndarray) -> np.ndarray:
"""Contrast distortion.
Args:
img (np.ndarray): The input image.
Returns:
np.ndarray: Image after contrast change.
"""
if random.randint(2):
return self.convert(
img,
alpha=random.uniform(self.contrast_lower, self.contrast_upper))
return img
def saturation(self, img: np.ndarray) -> np.ndarray:
"""Saturation distortion.
Args:
img (np.ndarray): The input image.
Returns:
np.ndarray: Image after saturation change.
"""
if random.randint(2):
img = mmcv.bgr2hsv(img)
img[:, :, 1] = self.convert(
img[:, :, 1],
alpha=random.uniform(self.saturation_lower,
self.saturation_upper))
img = mmcv.hsv2bgr(img)
return img
def hue(self, img: np.ndarray) -> np.ndarray:
"""Hue distortion.
Args:
img (np.ndarray): The input image.
Returns:
np.ndarray: Image after hue change.
"""
if random.randint(2):
img = mmcv.bgr2hsv(img)
img[:, :,
0] = (img[:, :, 0].astype(int) +
random.randint(-self.hue_delta, self.hue_delta)) % 180
img = mmcv.hsv2bgr(img)
return img
def transform(self, results: dict) -> dict:
"""Transform function to perform photometric distortion on images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Result dict with images distorted.
"""
def _photo_metric_distortion(img, contrast_mode=None):
# random brightness
img = self.brightness(img)
# mode == 0 --> do random contrast first
# mode == 1 --> do random contrast last
mode = contrast_mode or random.randint(2)
if mode == 1:
img = self.contrast(img)
# random saturation
img = self.saturation(img)
# random hue
img = self.hue(img)
# random contrast
if mode == 0:
img = self.contrast(img)
return img
contrast_mode = random.randint(2) \
if self.consistent_contrast_mode else None
results['img'] = [_photo_metric_distortion(img, contrast_mode) \
for img in results['img']]
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += (f'(brightness_delta={self.brightness_delta}, '
f'contrast_range=({self.contrast_lower}, '
f'{self.contrast_upper}), '
f'saturation_range=({self.saturation_lower}, '
f'{self.saturation_upper}), '
f'hue_delta={self.hue_delta}), '
f'consistent_contrast_mode={self.consistent_contrast_mode}')
return repr_str
@TRANSFORMS.register_module()
class MultiImgRandomCutOut(BaseTransform):
"""CutOut operation.
Randomly drop some regions of image used in
`Cutout <https://arxiv.org/abs/1708.04552>`_.
Required Keys:
- img
- gt_seg_map
Modified Keys:
- img
- gt_seg_map
Args:
prob (float): cutout probability.
n_holes (int | tuple[int, int]): Number of regions to be dropped.
If it is given as a list, number of holes will be randomly
selected from the closed interval [`n_holes[0]`, `n_holes[1]`].
cutout_shape (tuple[int, int] | list[tuple[int, int]]): The candidate
shape of dropped regions. It can be `tuple[int, int]` to use a
fixed cutout shape, or `list[tuple[int, int]]` to randomly choose
shape from the list.
cutout_ratio (tuple[float, float] | list[tuple[float, float]]): The
candidate ratio of dropped regions. It can be `tuple[float, float]`
to use a fixed ratio or `list[tuple[float, float]]` to randomly
choose ratio from the list. Please note that `cutout_shape`
and `cutout_ratio` cannot be both given at the same time.
fill_in (tuple[float, float, float] | tuple[int, int, int]): The value
of pixel to fill in the dropped regions. Default: (0, 0, 0).
seg_fill_in (int): The labels of pixel to fill in the dropped regions.
If seg_fill_in is None, skip. Default: None.
"""
def __init__(self,
prob,
n_holes,
cutout_shape=None,
cutout_ratio=None,
fill_in=(0, 0, 0),
seg_fill_in=None):
assert 0 <= prob and prob <= 1
assert (cutout_shape is None) ^ (cutout_ratio is None), \
'Either cutout_shape or cutout_ratio should be specified.'
assert (isinstance(cutout_shape, (list, tuple))
or isinstance(cutout_ratio, (list, tuple)))
if isinstance(n_holes, tuple):
assert len(n_holes) == 2 and 0 <= n_holes[0] < n_holes[1]
else:
n_holes = (n_holes, n_holes)
if seg_fill_in is not None:
assert (isinstance(seg_fill_in, int) and 0 <= seg_fill_in
and seg_fill_in <= 255)
self.prob = prob
self.n_holes = n_holes
self.fill_in = fill_in
self.seg_fill_in = seg_fill_in
self.with_ratio = cutout_ratio is not None
self.candidates = cutout_ratio if self.with_ratio else cutout_shape
if not isinstance(self.candidates, list):
self.candidates = [self.candidates]
@cache_randomness
def do_cutout(self):
return np.random.rand() < self.prob
@cache_randomness
def generate_patches(self, results):
cutout = self.do_cutout()
h, w, _ = results['img'][0].shape
if cutout:
n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1)
else:
n_holes = 0
x1_lst = []
y1_lst = []
index_lst = []
for _ in range(n_holes):
x1_lst.append(np.random.randint(0, w))
y1_lst.append(np.random.randint(0, h))
index_lst.append(np.random.randint(0, len(self.candidates)))
return cutout, n_holes, x1_lst, y1_lst, index_lst
def transform(self, results: dict) -> dict:
"""Call function to drop some regions of image."""
cutout, n_holes, x1_lst, y1_lst, index_lst = self.generate_patches(
results)
if cutout:
h, w, c = results['img'][0].shape
for i in range(n_holes):
x1 = x1_lst[i]
y1 = y1_lst[i]
index = index_lst[i]
if not self.with_ratio:
cutout_w, cutout_h = self.candidates[index]
else:
cutout_w = int(self.candidates[index][0] * w)
cutout_h = int(self.candidates[index][1] * h)
x2 = np.clip(x1 + cutout_w, 0, w)
y2 = np.clip(y1 + cutout_h, 0, h)
for idx in range(len(results['img'])):
results['img'][idx][y1:y2, x1:x2, :] = self.fill_in
if self.seg_fill_in is not None:
for key in results.get('seg_fields', []):
results[key][y1:y2, x1:x2] = self.seg_fill_in
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob}, '
repr_str += f'n_holes={self.n_holes}, '
repr_str += (f'cutout_ratio={self.candidates}, ' if self.with_ratio
else f'cutout_shape={self.candidates}, ')
repr_str += f'fill_in={self.fill_in}, '
repr_str += f'seg_fill_in={self.seg_fill_in})'
return repr_str
@TRANSFORMS.register_module()
class MultiImgRandomRotFlip(BaseTransform):
"""Rotate and flip the image & seg or just rotate the image & seg.
Required Keys:
- img
- gt_seg_map
Modified Keys:
- img
- gt_seg_map
Args:
rotate_prob (float): The probability of rotate image.
flip_prob (float): The probability of rotate&flip image.
degree (float, tuple[float]): Range of degrees to select from. If
degree is a number instead of tuple like (min, max),
the range of degree will be (``-degree``, ``+degree``)
pad_val (float, optional): Padding value of image. Default: 0.
seg_pad_val (float, optional): Padding value of segmentation map.
Default: 255.
"""
def __init__(self,
rotate_prob=0.5,
flip_prob=0.5,
degree=(-20, 20),
pad_val=0,
seg_pad_val=255):
self.rotate_prob = rotate_prob
self.flip_prob = flip_prob
self.pad_val = pad_val
self.seg_pad_val = seg_pad_val
assert 0 <= rotate_prob <= 1 and 0 <= flip_prob <= 1
if isinstance(degree, (float, int)):
assert degree > 0, f'degree {degree} should be positive'
self.degree = (-degree, degree)
else:
self.degree = degree
assert len(self.degree) == 2, f'degree {self.degree} should be a ' \
f'tuple of (min, max)'
def random_rot_flip(self, results: dict) -> dict:
k = np.random.randint(0, 4)
results['img'] = [np.rot90(img, k) for img in results['img']]
for key in results.get('seg_fields', []):
results[key] = np.rot90(results[key], k)
axis = np.random.randint(0, 2)
results['img'] = [
np.flip(img, axis=axis).copy() for img in results['img']]
for key in results.get('seg_fields', []):
results[key] = np.flip(results[key], axis=axis).copy()
return results
def random_rotate(self, results: dict) -> dict:
angle = np.random.uniform(min(*self.degree), max(*self.degree))
results['img'] = [
mmcv.imrotate(img, angle=angle,
border_value=self.pad_val) for img in results['img']]
for key in results.get('seg_fields', []):
results[key] = mmcv.imrotate(results[key],
angle=angle,
border_value=self.seg_pad_val,
interpolation='nearest')
return results
def transform(self, results: dict) -> dict:
"""Call function to rotate or rotate & flip image, semantic
segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Rotated or rotated & flipped results.
"""
rotate_flag = 0
if random.random() < self.rotate_prob:
results = self.random_rotate(results)
rotate_flag = 1
if random.random() < self.flip_prob and rotate_flag == 0:
results = self.random_rot_flip(results)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(rotate_prob={self.rotate_prob}, ' \
f'flip_prob={self.flip_prob}, ' \
f'degree={self.degree})'
return repr_str
@TRANSFORMS.register_module()
class MultiImgResizeShortestEdge(BaseTransform):
"""Resize the image and mask while keeping the aspect ratio unchanged.
Modified from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/augmentation_impl.py#L130 # noqa:E501
Copyright (c) Facebook, Inc. and its affiliates.
Licensed under the Apache-2.0 License
This transform attempts to scale the shorter edge to the given
`scale`, as long as the longer edge does not exceed `max_size`.
If `max_size` is reached, then downscale so that the longer
edge does not exceed `max_size`.
Required Keys:
- img
- gt_seg_map (optional)
Modified Keys:
- img
- img_shape
- gt_seg_map (optional)
Added Keys:
- scale
- scale_factor
- keep_ratio
Args:
scale (Union[int, Tuple[int, int]]): The target short edge length.
If it's tuple, will select the min value as the short edge length.
max_size (int): The maximum allowed longest edge length.
"""
def __init__(self, scale: Union[int, Tuple[int, int]],
max_size: int) -> None:
super().__init__()
self.scale = scale
self.max_size = max_size
# Create a empty Resize object
self.resize = TRANSFORMS.build({
'type': 'MultiImgResize',
'scale': 0,
'keep_ratio': True
})
def _get_output_shape(self, img, short_edge_length) -> Tuple[int, int]:
"""Compute the target image shape with the given `short_edge_length`.
Args:
img (np.ndarray): The input image.
short_edge_length (Union[int, Tuple[int, int]]): The target short
edge length. If it's tuple, will select the min value as the
short edge length.
"""
h, w = img.shape[:2]
if isinstance(short_edge_length, int):
size = short_edge_length * 1.0
elif isinstance(short_edge_length, tuple):
size = min(short_edge_length) * 1.0
scale = size / min(h, w)
if h < w:
new_h, new_w = size, scale * w
else:
new_h, new_w = scale * h, size
if max(new_h, new_w) > self.max_size:
scale = self.max_size * 1.0 / max(new_h, new_w)
new_h *= scale
new_w *= scale
new_h = int(new_h + 0.5)
new_w = int(new_w + 0.5)
return (new_w, new_h)
def transform(self, results: Dict) -> Dict:
self.resize.scale = self._get_output_shape(results['img'], self.scale)
return self.resize(results)
@TRANSFORMS.register_module()
class MultiImgExchangeTime(BaseTransform):
"""Exchange images of different times.
Args:
prob (float): probability of applying the transform. Default: 0.5.
"""
def __init__(self,
prob: float = 0.5) -> None:
assert 0 <= prob and prob <= 1
self.prob = prob
def transform(self, results: dict) -> dict:
"""Call function to exchange images ."""
exchange = True if np.random.rand() < self.prob else False
if exchange:
results['img'].reverse() # list.reverse()
if 'gt_seg_map_from' in results['seg_fields'] and \
'gt_seg_map_to' in results['seg_fields']:
results['gt_seg_map_from'], results['gt_seg_map_to'] = \
results['gt_seg_map_to'], results['gt_seg_map_from']
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob}, '
return repr_str
@TRANSFORMS.register_module()
class MultiImgResize(BaseTransform):
"""Resize images & seg.
This transform resizes the input image according to ``scale`` or
``scale_factor``. Bboxes, seg map and keypoints are then resized with the
same scale factor.
if ``scale`` and ``scale_factor`` are both set, it will use ``scale`` to
resize.
Required Keys:
- img
- gt_seg_map (optional)
Modified Keys:
- img
- gt_seg_map
- img_shape
Added Keys:
- scale
- scale_factor
- keep_ratio
Args:
scale (int or tuple): Images scales for resizing. Defaults to None
scale_factor (float or tuple[float]): Scale factors for resizing.
Defaults to None.
keep_ratio (bool): Whether to keep the aspect ratio when resizing the
image. Defaults to False.
clip_object_border (bool): Whether to clip the objects
outside the border of the image. In some dataset like MOT17, the gt
bboxes are allowed to cross the border of images. Therefore, we
don't need to clip the gt bboxes in these cases. Defaults to True.
backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
These two backends generates slightly different results. Defaults
to 'cv2'.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear" for 'pillow' backend. Defaults
to 'bilinear'.
"""
def __init__(self,
scale: Optional[Union[int, Tuple[int, int]]] = None,
scale_factor: Optional[Union[float, Tuple[float,
float]]] = None,
keep_ratio: bool = False,
clip_object_border: bool = True,
backend: str = 'cv2',
interpolation='bilinear') -> None:
assert scale is not None or scale_factor is not None, (
'`scale` and'
'`scale_factor` can not both be `None`')
if scale is None:
self.scale = None
else:
if isinstance(scale, int):
self.scale = (scale, scale)
else:
self.scale = scale
self.backend = backend
self.interpolation = interpolation
self.keep_ratio = keep_ratio
self.clip_object_border = clip_object_border
if scale_factor is None:
self.scale_factor = None
elif isinstance(scale_factor, float):
self.scale_factor = (scale_factor, scale_factor)
elif isinstance(scale_factor, tuple):
assert (len(scale_factor)) == 2
self.scale_factor = scale_factor
else:
raise TypeError(
f'expect scale_factor is float or Tuple(float), but'
f'get {type(scale_factor)}')
def _resize_img(self, results: dict) -> None:
"""Resize images with ``results['scale']``."""
if results.get('img', None) is not None:
res_imgs = []
for img in results['img']:
if self.keep_ratio:
img, scale_factor = mmcv.imrescale(
img,
results['scale'],
interpolation=self.interpolation,
return_scale=True,
backend=self.backend)
# the w_scale and h_scale has minor difference
# a real fix should be done in the mmcv.imrescale in the future
new_h, new_w = img.shape[:2]
h, w = img.shape[:2]
w_scale = new_w / w
h_scale = new_h / h
else:
img, w_scale, h_scale = mmcv.imresize(
img,
results['scale'],
interpolation=self.interpolation,
return_scale=True,
backend=self.backend)
res_imgs.append(img)
results['img'] = res_imgs
results['img_shape'] = res_imgs[0].shape[:2]
results['scale_factor'] = (w_scale, h_scale)
results['keep_ratio'] = self.keep_ratio
def _resize_seg(self, results: dict) -> None:
"""Resize semantic segmentation map with ``results['scale']``."""
for key in results.get('seg_fields', []):
if self.keep_ratio:
gt_seg = mmcv.imrescale(
results[key],
results['scale'],
interpolation='nearest',
backend=self.backend)
else:
gt_seg = mmcv.imresize(
results[key],
results['scale'],
interpolation='nearest',
backend=self.backend)
results[key] = gt_seg
def transform(self, results: dict) -> dict:
"""Transform function to resize images, semantic
segmentation map.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Resized results, 'img', 'gt_seg_map',
'scale', 'scale_factor', 'img_shape',
and 'keep_ratio' keys are updated in result dict.
"""
if self.scale:
results['scale'] = self.scale
else:
img_shape = results['img'][0].shape[:2]
results['scale'] = _scale_size(img_shape[::-1],
self.scale_factor) # type: ignore
self._resize_img(results)
self._resize_seg(results)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(scale={self.scale}, '
repr_str += f'scale_factor={self.scale_factor}, '
repr_str += f'keep_ratio={self.keep_ratio}, '
repr_str += f'clip_object_border={self.clip_object_border}), '
repr_str += f'backend={self.backend}), '
repr_str += f'interpolation={self.interpolation})'
return repr_str
@TRANSFORMS.register_module()
class MultiImgRandomResize(BaseTransform):
"""Random resize images.
How to choose the target scale to resize the image will follow the rules
below:
- if ``scale`` is a sequence of tuple
.. math::
target\\_scale[0] \\sim Uniform([scale[0][0], scale[1][0]])
.. math::
target\\_scale[1] \\sim Uniform([scale[0][1], scale[1][1]])
Following the resize order of weight and height in cv2, ``scale[i][0]``
is for width, and ``scale[i][1]`` is for height.
- if ``scale`` is a tuple
.. math::
target\\_scale[0] \\sim Uniform([ratio\\_range[0], ratio\\_range[1]])
* scale[0]
.. math::
target\\_scale[0] \\sim Uniform([ratio\\_range[0], ratio\\_range[1]])
* scale[1]
Following the resize order of weight and height in cv2, ``ratio_range[0]``
is for width, and ``ratio_range[1]`` is for height.
- if ``keep_ratio`` is True, the minimum value of ``target_scale`` will be
used to set the shorter side and the maximum value will be used to
set the longer side.
- if ``keep_ratio`` is False, the value of ``target_scale`` will be used to
reisze the width and height accordingly.
Required Keys:
- img
- gt_seg_map
Modified Keys:
- img
- gt_seg_map
- img_shape
Added Keys:
- scale
- scale_factor
- keep_ratio
Args:
scale (tuple or Sequence[tuple]): Images scales for resizing.
Defaults to None.
ratio_range (tuple[float], optional): (min_ratio, max_ratio).
Defaults to None.
resize_type (str): The type of resize class to use. Defaults to
"Resize".
**resize_kwargs: Other keyword arguments for the ``resize_type``.
Note:
By defaults, the ``resize_type`` is "Resize", if it's not overwritten
by your registry, it indicates the :class:`mmcv.Resize`. And therefore,
``resize_kwargs`` accepts any keyword arguments of it, like
``keep_ratio``, ``interpolation`` and so on.
If you want to use your custom resize class, the class should accept
``scale`` argument and have ``scale`` attribution which determines the
resize shape.
"""
def __init__(
self,
scale: Union[Tuple[int, int], Sequence[Tuple[int, int]]],
ratio_range: Tuple[float, float] = None,
resize_type: str = 'MultiImgResize',
**resize_kwargs,
) -> None:
self.scale = scale
self.ratio_range = ratio_range
self.resize_cfg = dict(type=resize_type, **resize_kwargs)
# create a empty Reisize object
self.resize = TRANSFORMS.build({'scale': 0, **self.resize_cfg})
@staticmethod
def _random_sample(scales: Sequence[Tuple[int, int]]) -> tuple:
"""Private function to randomly sample a scale from a list of tuples.
Args:
scales (list[tuple]): Images scale range for sampling.
There must be two tuples in scales, which specify the lower
and upper bound of image scales.
Returns:
tuple: The targeted scale of the image to be resized.
"""
assert is_list_of(scales, tuple) and len(scales) == 2
scale_0 = [scales[0][0], scales[1][0]]
scale_1 = [scales[0][1], scales[1][1]]
edge_0 = np.random.randint(min(scale_0), max(scale_0) + 1)
edge_1 = np.random.randint(min(scale_1), max(scale_1) + 1)
scale = (edge_0, edge_1)
return scale
@staticmethod
def _random_sample_ratio(scale: tuple, ratio_range: Tuple[float,
float]) -> tuple:
"""Private function to randomly sample a scale from a tuple.
A ratio will be randomly sampled from the range specified by
``ratio_range``. Then it would be multiplied with ``scale`` to
generate sampled scale.
Args:
scale (tuple): Images scale base to multiply with ratio.
ratio_range (tuple[float]): The minimum and maximum ratio to scale
the ``scale``.
Returns:
tuple: The targeted scale of the image to be resized.
"""
assert isinstance(scale, tuple) and len(scale) == 2
min_ratio, max_ratio = ratio_range
assert min_ratio <= max_ratio
ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
scale = int(scale[0] * ratio), int(scale[1] * ratio)
return scale
@cache_randomness
def _random_scale(self) -> tuple:
"""Private function to randomly sample an scale according to the type
of ``scale``.
Returns:
tuple: The targeted scale of the image to be resized.
"""
if is_tuple_of(self.scale, int):
assert self.ratio_range is not None and len(self.ratio_range) == 2
scale = self._random_sample_ratio(
self.scale, # type: ignore
self.ratio_range)
elif is_seq_of(self.scale, tuple):
scale = self._random_sample(self.scale) # type: ignore
else:
raise NotImplementedError('Do not support sampling function '
f'for "{self.scale}"')
return scale
def transform(self, results: dict) -> dict:
"""Transform function to resize images, bounding boxes, semantic
segmentation map.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Resized results, ``img``, ``gt_semantic_seg``,
``scale``, ``scale_factor``, ``img_shape``, and
``keep_ratio`` keys are updated in result dict.
"""
results['scale'] = self._random_scale()
self.resize.scale = results['scale']
results = self.resize(results)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(scale={self.scale}, '
repr_str += f'ratio_range={self.ratio_range}, '
repr_str += f'resize_cfg={self.resize_cfg})'
return repr_str
@TRANSFORMS.register_module()
class MultiImgNormalize(BaseTransform):
"""Normalize the images.
Required Keys:
- img
Modified Keys:
- img
Added Keys:
- img_norm_cfg
- mean
- std
- to_rgb
Args:
mean (sequence): Mean values of 3 channels.
std (sequence): Std values of 3 channels.
to_rgb (bool): Whether to convert the image from BGR to RGB before
normlizing the image. If ``to_rgb=True``, the order of mean and std
should be RGB. If ``to_rgb=False``, the order of mean and std
should be the same order of the image. Defaults to True.
"""
def __init__(self,
mean: Sequence[Union[int, float]],
std: Sequence[Union[int, float]],
to_rgb: bool = True) -> None:
self.mean = np.array(mean, dtype=np.float32)
self.std = np.array(std, dtype=np.float32)
self.to_rgb = to_rgb
def transform(self, results: dict) -> dict:
"""Function to normalize images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Normalized results, key 'img_norm_cfg' key is added in to
result dict.
"""
results['img'] = [
mmcv.imnormalize(img, self.mean, self.std, self.to_rgb)
for img in results['img']]
results['img_norm_cfg'] = dict(
mean=self.mean, std=self.std, to_rgb=self.to_rgb)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})'
return repr_str
@TRANSFORMS.register_module()
class MultiImgRandomFlip(BaseTransform):
"""Flip the image & segmentation map. Added or Updated
keys: flip, flip_direction, img, gt_bboxes, gt_seg_map, and
gt_keypoints. There are 3 flip modes:
- ``prob`` is float, ``direction`` is string: the image will be
``direction``ly flipped with probability of ``prob`` .
E.g., ``prob=0.5``, ``direction='horizontal'``,
then image will be horizontally flipped with probability of 0.5.
- ``prob`` is float, ``direction`` is list of string: the image will
be ``direction[i]``ly flipped with probability of
``prob/len(direction)``.
E.g., ``prob=0.5``, ``direction=['horizontal', 'vertical']``,
then image will be horizontally flipped with probability of 0.25,
vertically with probability of 0.25.
- ``prob`` is list of float, ``direction`` is list of string:
given ``len(prob) == len(direction)``, the image will
be ``direction[i]``ly flipped with probability of ``prob[i]``.
E.g., ``prob=[0.3, 0.5]``, ``direction=['horizontal',
'vertical']``, then image will be horizontally flipped with
probability of 0.3, vertically with probability of 0.5.
Required Keys:
- img
- gt_seg_map
Modified Keys:
- img
- gt_seg_map
Added Keys:
- flip
- flip_direction
Args:
prob (float | list[float], optional): The flipping probability.
Defaults to None.
direction(str | list[str]): The flipping direction. Options
If input is a list, the length must equal ``prob``. Each
element in ``prob`` indicates the flip probability of
corresponding direction. Defaults to 'horizontal'.
"""
def __init__(self,
prob: Optional[Union[float, Iterable[float]]] = None,
direction: Union[str, Sequence[Optional[str]]] = 'horizontal') -> None:
if isinstance(prob, list):
assert is_list_of(prob, float)
assert 0 <= sum(prob) <= 1
elif isinstance(prob, float):
assert 0 <= prob <= 1
else:
raise ValueError(f'probs must be float or list of float, but \
got `{type(prob)}`.')
self.prob = prob
valid_directions = ['horizontal', 'vertical', 'diagonal']
if isinstance(direction, str):
assert direction in valid_directions
elif isinstance(direction, list):
assert is_list_of(direction, str)
assert set(direction).issubset(set(valid_directions))
else:
raise ValueError(f'direction must be either str or list of str, \
but got `{type(direction)}`.')
self.direction = direction
if isinstance(prob, list):
assert len(prob) == len(self.direction)
@cache_randomness
def _choose_direction(self) -> str:
"""Choose the flip direction according to `prob` and `direction`"""
if isinstance(self.direction,
Sequence) and not isinstance(self.direction, str):
# None means non-flip
direction_list: list = list(self.direction) + [None]
elif isinstance(self.direction, str):
# None means non-flip
direction_list = [self.direction, None]
if isinstance(self.prob, list):
non_prob: float = 1 - sum(self.prob)
prob_list = self.prob + [non_prob]
elif isinstance(self.prob, float):
non_prob = 1. - self.prob
# exclude non-flip
single_ratio = self.prob / (len(direction_list) - 1)
prob_list = [single_ratio] * (len(direction_list) - 1) + [non_prob]
cur_dir = np.random.choice(direction_list, p=prob_list)
return cur_dir
def transform(self, results: dict) -> dict:
"""Transform function to flip images, semantic
segmentation map.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Flipped results, 'img', 'gt_seg_map',
'flip', and 'flip_direction' keys are
updated in result dict.
"""
cur_dir = self._choose_direction()
if cur_dir is None:
results['flip'] = False
results['flip_direction'] = None
else:
results['flip'] = True
results['flip_direction'] = cur_dir
# flip image
results['img'] = [
mmcv.imflip(img, direction=results['flip_direction'])
for img in results['img']
]
# flip segs
for key in results.get('seg_fields', []):
# use copy() to make numpy stride positive
results[key] = mmcv.imflip(
results[key], direction=results['flip_direction']).copy()
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob}, '
repr_str += f'direction={self.direction})'
return repr_str
@TRANSFORMS.register_module()
class MultiImgPad(BaseTransform):
"""Pad the image & segmentation map.
There are three padding modes: (1) pad to a fixed size and (2) pad to the
minimum size that is divisible by some number. and (3)pad to square. Also,
pad to square and pad to the minimum size can be used as the same time.
Required Keys:
- img
- gt_seg_map
Modified Keys:
- img
- gt_seg_map
- img_shape
Added Keys:
- pad_shape
- pad_fixed_size
- pad_size_divisor
Args:
size (tuple, optional): Fixed padding size.
Expected padding shape (w, h). Defaults to None.
size_divisor (int, optional): The divisor of padded size. Defaults to
None.
pad_to_square (bool): Whether to pad the image into a square.
Currently only used for YOLOX. Defaults to False.
pad_val (Number | dict[str, Number], optional): Padding value for if
the pad_mode is "constant". If it is a single number, the value
to pad the image is the number and to pad the semantic
segmentation map is 255. If it is a dict, it should have the
following keys:
- img: The value to pad the image.
- seg: The value to pad the semantic segmentation map.
Defaults to dict(img=0, seg=255).
padding_mode (str): Type of padding. Should be: constant, edge,
reflect or symmetric. Defaults to 'constant'.
- constant: pads with a constant value, this value is specified
with pad_val.
- edge: pads with the last value at the edge of the image.
- reflect: pads with reflection of image without repeating the last
value on the edge. For example, padding [1, 2, 3, 4] with 2
elements on both sides in reflect mode will result in
[3, 2, 1, 2, 3, 4, 3, 2].
- symmetric: pads with reflection of image repeating the last value
on the edge. For example, padding [1, 2, 3, 4] with 2 elements on
both sides in symmetric mode will result in
[2, 1, 1, 2, 3, 4, 4, 3]
"""
def __init__(self,
size: Optional[Tuple[int, int]] = None,
size_divisor: Optional[int] = None,
pad_to_square: bool = False,
pad_val: Union[int, float, dict] = dict(img=0, seg=255),
padding_mode: str = 'constant') -> None:
self.size = size
self.size_divisor = size_divisor
if isinstance(pad_val, int):
pad_val = dict(img=pad_val, seg=255)
assert isinstance(pad_val, dict), 'pad_val '
self.pad_val = pad_val
self.pad_to_square = pad_to_square
if pad_to_square:
assert size is None, \
'The size and size_divisor must be None ' \
'when pad2square is True'
else:
assert size is not None or size_divisor is not None, \
'only one of size and size_divisor should be valid'
assert size is None or size_divisor is None
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
self.padding_mode = padding_mode
def _pad_img(self, results: dict) -> None:
"""Pad images according to ``self.size``."""
pad_val = self.pad_val.get('img', 0)
size = None
if self.pad_to_square:
max_size = max(results['img'][0].shape[:2])
size = (max_size, max_size)
if self.size_divisor is not None:
if size is None:
size = (results['img'][0].shape[0], results['img'].shape[1])
pad_h = int(np.ceil(
size[0] / self.size_divisor)) * self.size_divisor
pad_w = int(np.ceil(
size[1] / self.size_divisor)) * self.size_divisor
size = (pad_h, pad_w)
elif self.size is not None:
size = self.size[::-1]
if isinstance(pad_val, int) and results['img'][0].ndim == 3:
pad_val = tuple(pad_val for _ in range(results['img'][0].shape[2]))
padded_imgs = [
mmcv.impad(
img,
shape=size,
pad_val=pad_val,
padding_mode=self.padding_mode) for img in results['img']]
results['img'] = padded_imgs
results['pad_shape'] = padded_imgs[0].shape
results['pad_fixed_size'] = self.size
results['pad_size_divisor'] = self.size_divisor
results['img_shape'] = padded_imgs[0].shape[:2]
def _pad_seg(self, results: dict) -> None:
"""Pad semantic segmentation map according to
``results['pad_shape']``."""
pad_val = self.pad_val.get('seg', 255)
for key in results.get('seg_fields', []):
results[key] = mmcv.impad(
results[key],
shape=results['pad_shape'][:2],
pad_val=pad_val,
padding_mode=self.padding_mode)
def transform(self, results: dict) -> dict:
"""Call function to pad images, masks, semantic segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Updated result dict.
"""
self._pad_img(results)
self._pad_seg(results)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(size={self.size}, '
repr_str += f'size_divisor={self.size_divisor}, '
repr_str += f'pad_to_square={self.pad_to_square}, '
repr_str += f'pad_val={self.pad_val}), '
repr_str += f'padding_mode={self.padding_mode})'
return repr_str
@TRANSFORMS.register_module()
class MultiImgAlbu(BaseTransform):
"""Albumentation augmentation. Adds custom transformations from
Albumentations library. Please, visit
`https://albumentations.readthedocs.io` to get more information. An example
of ``transforms`` is as followed:
.. code-block::
[
dict(
type='ShiftScaleRotate',
shift_limit=0.0625,
scale_limit=0.0,
rotate_limit=0,
interpolation=1,
p=0.5),
dict(
type='RandomBrightnessContrast',
brightness_limit=[0.1, 0.3],
contrast_limit=[0.1, 0.3],
p=0.2),
dict(type='ChannelShuffle', p=0.1),
dict(
type='OneOf',
transforms=[
dict(type='Blur', blur_limit=3, p=1.0),
dict(type='MedianBlur', blur_limit=3, p=1.0)
],
p=0.1),
]
Args:
transforms (list[dict]): A list of albu transformations
keymap (dict): Contains {'input key':'albumentation-style key'}
update_pad_shape (bool): Whether update final shape.
additional_targets (dict): Dict with keys - new target name,
values - old target name. ex: {'image2': 'image'}.
"""
def __init__(self,
transforms: List[dict],
keymap: dict = None,
update_pad_shape: bool = False,
additional_targets: dict = None) -> None:
# Args will be modified later, copying it will be safer
transforms = copy.deepcopy(transforms)
if keymap is not None:
keymap = copy.deepcopy(keymap)
self.transforms = transforms
self.filter_lost_elements = False
self.update_pad_shape = update_pad_shape
self.additional_targets = additional_targets
self.aug = Compose([self.albu_builder(t) for t in self.transforms], \
additional_targets=self.additional_targets)
if not keymap:
self.keymap_to_albu = {'img': 'image', 'gt_semantic_seg': 'mask'}
else:
self.keymap_to_albu = keymap
self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}
def albu_builder(self, cfg):
"""Import a module from albumentations.
It inherits some of :func:`build_from_cfg` logic.
Args:
cfg (dict): Config dict. It should at least contain the key "type".
Returns:
obj: The constructed object.
"""
assert isinstance(cfg, dict) and 'type' in cfg
args = cfg.copy()
obj_type = args.pop('type')
if is_str(obj_type):
obj_cls = getattr(albumentations, obj_type)
else:
raise TypeError(f'type must be str, but got {type(obj_type)}')
if 'transforms' in args:
args['transforms'] = [
self.albu_builder(transform)
for transform in args['transforms']
]
return obj_cls(**args)
@staticmethod
def mapper(d: dict, keymap: dict) -> dict:
"""Dictionary mapper.
Renames keys according to keymap provided.
Args:
d (dict): old dict
keymap (dict): {'old_key':'new_key'}
Returns:
dict: new dict.
"""
updated_dict = {}
for k, v in zip(d.keys(), d.values()):
new_k = keymap.get(k, k)
updated_dict[new_k] = d[k]
if updated_dict.get('image', None) is not None:
updated_dict['image'] = np.concatenate(updated_dict['image'], axis=-1)
if updated_dict.get('img', None) is not None:
updated_dict['img'] = np.split(updated_dict['img'], indices_or_sections=2, axis=-1)
return updated_dict
def transform(self, results: dict) -> dict:
# dict to albumentations format
results = self.mapper(results, self.keymap_to_albu)
results = self.aug(**results)
# back to the original format
results = self.mapper(results, self.keymap_back)
# update final shape
if self.update_pad_shape:
results['pad_shape'] = results['img'][0].shape
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(transforms={self.transforms}, '
repr_str += f'(update_pad_shape={self.update_pad_shape}, '
repr_str += f'(additional_targets={self.additional_targets})'
return repr_str