|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Implement AutoAugment, RandAugment |
|
Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py and modified for token labeling |
|
""" |
|
import math |
|
import random |
|
import re |
|
|
|
import numpy as np |
|
import PIL |
|
from PIL import Image, ImageEnhance, ImageOps |
|
|
|
# (major, minor) of the installed Pillow; used to gate version-specific code paths.
_PIL_VER = tuple(int(part) for part in PIL.__version__.split(".")[:2])

# Fallback fill colour (mid grey) used when hparams carry no "img_mean".
_FILL = (128, 128, 128)

# Upper bound of the magnitude scale; level functions divide by this value.
_MAX_LEVEL = 10.0

# Hyper-parameters used whenever the caller supplies none.
_HPARAMS_DEFAULT = {
    "translate_const": 250,
    "img_mean": _FILL,
}

# Resampling filters to choose from at random when no "interpolation" hparam is set.
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
|
|
|
|
|
def _interpolation(kwargs):
    """Pop ``"resample"`` from *kwargs* and resolve it to a single filter.

    A list or tuple is treated as a set of candidates and one is drawn at
    random; any other value is returned unchanged. When the key is absent
    ``Image.BILINEAR`` is used. The key is removed from *kwargs*.
    """
    candidate = kwargs.pop("resample", Image.BILINEAR)
    if not isinstance(candidate, (list, tuple)):
        return candidate
    return random.choice(candidate)
|
|
|
|
|
def _check_args_tf(kwargs):
    """Normalise transform keyword arguments in place.

    Drops ``fillcolor`` when the installed PIL is older than 5.0 and replaces
    ``resample`` with a single filter resolved by ``_interpolation``.
    """
    pil_lacks_fillcolor = _PIL_VER < (5, 0)
    if pil_lacks_fillcolor and "fillcolor" in kwargs:
        del kwargs["fillcolor"]
    kwargs["resample"] = _interpolation(kwargs)
|
|
|
|
|
def shear_x(img, factor, **kwargs):
    """Apply an affine transform with *factor* as the x-from-y coefficient.

    Extra keyword arguments (``fillcolor``, ``resample``) are normalised by
    ``_check_args_tf`` and forwarded to ``img.transform``.
    """
    _check_args_tf(kwargs)
    coeffs = (1, factor, 0, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
|
|
|
|
|
def shear_y(img, factor, **kwargs):
    """Apply an affine transform with *factor* as the y-from-x coefficient.

    Extra keyword arguments (``fillcolor``, ``resample``) are normalised by
    ``_check_args_tf`` and forwarded to ``img.transform``.
    """
    _check_args_tf(kwargs)
    coeffs = (1, 0, 0, factor, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
|
|
|
|
|
def translate_x_rel(img, pct, **kwargs):
    """Translate along x by *pct* times the image width (``img.size[0]``)."""
    offset = pct * img.size[0]
    _check_args_tf(kwargs)
    coeffs = (1, 0, offset, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
|
|
|
|
|
def translate_y_rel(img, pct, **kwargs):
    """Translate along y by *pct* times the image height (``img.size[1]``)."""
    offset = pct * img.size[1]
    _check_args_tf(kwargs)
    coeffs = (1, 0, 0, 0, 1, offset)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
|
|
|
|
|
def translate_x_abs(img, pixels, **kwargs):
    """Translate along x using *pixels* as the affine x offset."""
    _check_args_tf(kwargs)
    coeffs = (1, 0, pixels, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
|
|
|
|
|
def translate_y_abs(img, pixels, **kwargs):
    """Translate along y using *pixels* as the affine y offset."""
    _check_args_tf(kwargs)
    coeffs = (1, 0, 0, 0, 1, pixels)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
|
|
|
|
|
def rotate(img, degrees, **kwargs):
    """Rotate *img* by *degrees*, choosing a code path by PIL version.

    * PIL >= 5.2: ``img.rotate`` is called with all keyword arguments.
    * 5.0 <= PIL < 5.2: the rotation about the image centre is built by hand
      as an affine matrix and applied through ``img.transform``.
    * PIL < 5.0: ``img.rotate`` is called with only the ``resample`` argument.
    """
    _check_args_tf(kwargs)

    if _PIL_VER >= (5, 2):
        return img.rotate(degrees, **kwargs)
    if not _PIL_VER >= (5, 0):
        return img.rotate(degrees, resample=kwargs["resample"])

    # Manual rotation about the image centre.
    width, height = img.size
    shift_x, shift_y = 0, 0
    centre_x, centre_y = width / 2.0, height / 2.0
    theta = -math.radians(degrees)

    a = round(math.cos(theta), 15)
    b = round(math.sin(theta), 15)
    d = round(-math.sin(theta), 15)
    e = round(math.cos(theta), 15)

    # Map the negated centre through the linear part (offsets start at 0.0),
    # then shift back by the centre so the rotation pivots on it.
    src_x = -centre_x - shift_x
    src_y = -centre_y - shift_y
    c = a * src_x + b * src_y + 0.0
    f = d * src_x + e * src_y + 0.0
    c += centre_x
    f += centre_y

    affine = [a, b, c, d, e, f]
    return img.transform(img.size, Image.AFFINE, affine, **kwargs)
|
|
|
|
|
def auto_contrast(img, **__):
    """Return ``ImageOps.autocontrast(img)``; extra keyword arguments are ignored."""
    result = ImageOps.autocontrast(img)
    return result
|
|
|
|
|
def invert(img, **__):
    """Return ``ImageOps.invert(img)``; extra keyword arguments are ignored."""
    result = ImageOps.invert(img)
    return result
|
|
|
|
|
def equalize(img, **__):
    """Return ``ImageOps.equalize(img)``; extra keyword arguments are ignored."""
    result = ImageOps.equalize(img)
    return result
|
|
|
|
|
def solarize(img, thresh, **__):
    """Return ``ImageOps.solarize(img, thresh)``; extra keyword arguments are ignored."""
    result = ImageOps.solarize(img, thresh)
    return result
|
|
|
|
|
def solarize_add(img, add, thresh=128, **__):
    """Brighten values below *thresh* by *add*, saturating at 255.

    A 256-entry lookup table is built and applied with ``img.point``. For
    ``"RGB"`` images the table is repeated three times. Images whose mode is
    neither ``"L"`` nor ``"RGB"`` are returned untouched.
    """
    table = [min(255, value + add) if value < thresh else value for value in range(256)]
    if img.mode not in ("L", "RGB"):
        return img
    if img.mode == "RGB":
        table = table * 3
    return img.point(table)
|
|
|
|
|
def posterize(img, bits_to_keep, **__):
    """Posterize *img* with ``ImageOps.posterize``.

    When *bits_to_keep* is 8 or more the image is returned unchanged.
    """
    keep_everything = bits_to_keep >= 8
    return img if keep_everything else ImageOps.posterize(img, bits_to_keep)
|
|
|
|
|
def contrast(img, factor, **__):
    """Return *img* enhanced by ``ImageEnhance.Contrast`` with *factor*."""
    enhancer = ImageEnhance.Contrast(img)
    return enhancer.enhance(factor)
|
|
|
|
|
def color(img, factor, **__):
    """Return *img* enhanced by ``ImageEnhance.Color`` with *factor*."""
    enhancer = ImageEnhance.Color(img)
    return enhancer.enhance(factor)
|
|
|
|
|
def brightness(img, factor, **__):
    """Return *img* enhanced by ``ImageEnhance.Brightness`` with *factor*."""
    enhancer = ImageEnhance.Brightness(img)
    return enhancer.enhance(factor)
|
|
|
|
|
def sharpness(img, factor, **__):
    """Return *img* enhanced by ``ImageEnhance.Sharpness`` with *factor*."""
    enhancer = ImageEnhance.Sharpness(img)
    return enhancer.enhance(factor)
|
|
|
|
|
def _randomly_negate(v): |
|
"""With 50% prob, negate the value""" |
|
return -v if random.random() > 0.5 else v |
|
|
|
|
|
def _rotate_level_to_arg(level, _hparams):
    """Map *level* to a rotation angle in [-30, 30] with a random sign.

    Returns a 1-tuple suitable for unpacking into ``rotate``.
    """
    degrees = _randomly_negate((level / _MAX_LEVEL) * 30.0)
    return (degrees,)
|
|
|
|
|
def _enhance_level_to_arg(level, _hparams):
    """Map *level* linearly onto an enhancement factor in [0.1, 1.9].

    Returns a 1-tuple holding the factor.
    """
    factor = (level / _MAX_LEVEL) * 1.8 + 0.1
    return (factor,)
|
|
|
|
|
def _enhance_increasing_level_to_arg(level, _hparams):
    """Map *level* to an enhancement factor of ``1.0 +/- offset``.

    The offset grows with *level* up to 0.9 and its sign is chosen at random,
    so a higher level moves the factor further from the neutral 1.0.
    Returns a 1-tuple holding the factor.
    """
    offset = _randomly_negate((level / _MAX_LEVEL) * 0.9)
    return (1.0 + offset,)
|
|
|
|
|
def _shear_level_to_arg(level, _hparams):
    """Map *level* to a shear factor in [-0.3, 0.3] with a random sign.

    Returns a 1-tuple holding the factor.
    """
    shear = _randomly_negate((level / _MAX_LEVEL) * 0.3)
    return (shear,)
|
|
|
|
|
def _translate_abs_level_to_arg(level, hparams):
    """Map *level* to a translation offset with a random sign.

    The offset is scaled by ``hparams["translate_const"]`` (required key).
    Returns a 1-tuple holding the signed offset.
    """
    max_shift = float(hparams["translate_const"])
    shift = _randomly_negate((level / _MAX_LEVEL) * max_shift)
    return (shift,)
|
|
|
|
|
def _translate_rel_level_to_arg(level, hparams):
    """Map *level* to a relative translation fraction with a random sign.

    The maximum fraction is ``hparams["translate_pct"]`` and defaults to 0.45
    when that key is absent. Returns a 1-tuple holding the signed fraction.
    """
    max_pct = hparams.get("translate_pct", 0.45)
    pct = _randomly_negate((level / _MAX_LEVEL) * max_pct)
    return (pct,)
|
|
|
|
|
def _posterize_level_to_arg(level, _hparams):
    """Map *level* to a posterize bit count in the range 0..4.

    The scaled value is truncated with ``int``. Returns a 1-tuple.
    """
    bits = int((level / _MAX_LEVEL) * 4)
    return (bits,)
|
|
|
|
|
def _posterize_increasing_level_to_arg(level, hparams):
    """Map *level* to a posterize bit count that shrinks from 4 to 0.

    This is ``4 - bits`` where ``bits`` comes from
    ``_posterize_level_to_arg``, so a higher level keeps fewer bits.
    Returns a 1-tuple.
    """
    bits = _posterize_level_to_arg(level, hparams)[0]
    return (4 - bits,)
|
|
|
|
|
def _posterize_original_level_to_arg(level, _hparams):
    """Map *level* to a posterize bit count in the range 4..8.

    The scaled value is truncated with ``int`` before 4 is added.
    Returns a 1-tuple.
    """
    bits = int((level / _MAX_LEVEL) * 4) + 4
    return (bits,)
|
|
|
|
|
def _solarize_level_to_arg(level, _hparams):
    """Map *level* to a solarize threshold in the range 0..256.

    The scaled value is truncated with ``int``. Returns a 1-tuple.
    """
    threshold = int((level / _MAX_LEVEL) * 256)
    return (threshold,)
|
|
|
|
|
def _solarize_increasing_level_to_arg(level, _hparams):
    """Map *level* to a solarize threshold that falls from 256 to 0.

    This is ``256 - threshold`` where ``threshold`` comes from
    ``_solarize_level_to_arg``. Returns a 1-tuple.
    """
    threshold = _solarize_level_to_arg(level, _hparams)[0]
    return (256 - threshold,)
|
|
|
|
|
def _solarize_add_level_to_arg(level, _hparams):
    """Map *level* to the additive amount for ``solarize_add`` (0..110).

    The scaled value is truncated with ``int``. Returns a 1-tuple.
    """
    amount = int((level / _MAX_LEVEL) * 110)
    return (amount,)
|
|
|
|
|
# Op name -> function converting a magnitude level into the op's positional
# argument tuple. ``None`` marks ops that take no level-derived argument.
LEVEL_TO_ARG = dict(
    # Ops without a magnitude argument.
    AutoContrast=None,
    Equalize=None,
    Invert=None,
    # Geometric rotation.
    Rotate=_rotate_level_to_arg,
    # Posterize / solarize family.
    Posterize=_posterize_level_to_arg,
    PosterizeIncreasing=_posterize_increasing_level_to_arg,
    PosterizeOriginal=_posterize_original_level_to_arg,
    Solarize=_solarize_level_to_arg,
    SolarizeIncreasing=_solarize_increasing_level_to_arg,
    SolarizeAdd=_solarize_add_level_to_arg,
    # ImageEnhance-based ops.
    Color=_enhance_level_to_arg,
    ColorIncreasing=_enhance_increasing_level_to_arg,
    Contrast=_enhance_level_to_arg,
    ContrastIncreasing=_enhance_increasing_level_to_arg,
    Brightness=_enhance_level_to_arg,
    BrightnessIncreasing=_enhance_increasing_level_to_arg,
    Sharpness=_enhance_level_to_arg,
    SharpnessIncreasing=_enhance_increasing_level_to_arg,
    # Shear and translation.
    ShearX=_shear_level_to_arg,
    ShearY=_shear_level_to_arg,
    TranslateX=_translate_abs_level_to_arg,
    TranslateY=_translate_abs_level_to_arg,
    TranslateXRel=_translate_rel_level_to_arg,
    TranslateYRel=_translate_rel_level_to_arg,
)
|
|
|
|
|
# Op name -> function that performs the augmentation on an image. Several
# names share one function and differ only in their LEVEL_TO_ARG mapping.
NAME_TO_OP = dict(
    # Ops without a magnitude argument.
    AutoContrast=auto_contrast,
    Equalize=equalize,
    Invert=invert,
    # Geometric rotation.
    Rotate=rotate,
    # Posterize / solarize family.
    Posterize=posterize,
    PosterizeIncreasing=posterize,
    PosterizeOriginal=posterize,
    Solarize=solarize,
    SolarizeIncreasing=solarize,
    SolarizeAdd=solarize_add,
    # ImageEnhance-based ops.
    Color=color,
    ColorIncreasing=color,
    Contrast=contrast,
    ContrastIncreasing=contrast,
    Brightness=brightness,
    BrightnessIncreasing=brightness,
    Sharpness=sharpness,
    SharpnessIncreasing=sharpness,
    # Shear and translation.
    ShearX=shear_x,
    ShearY=shear_y,
    TranslateX=translate_x_abs,
    TranslateY=translate_y_abs,
    TranslateXRel=translate_x_rel,
    TranslateYRel=translate_y_rel,
)
|
|
|
_RAND_TRANSFORMS = [ |
|
"AutoContrast", |
|
"Equalize", |
|
"Invert", |
|
"Rotate", |
|
"Posterize", |
|
"Solarize", |
|
"SolarizeAdd", |
|
"Color", |
|
"Contrast", |
|
"Brightness", |
|
"Sharpness", |
|
"ShearX", |
|
"ShearY", |
|
"TranslateXRel", |
|
"TranslateYRel", |
|
|
|
] |
|
|
|
|
|
_RAND_INCREASING_TRANSFORMS = [ |
|
"AutoContrast", |
|
"Equalize", |
|
"Invert", |
|
"Rotate", |
|
"PosterizeIncreasing", |
|
"SolarizeIncreasing", |
|
"SolarizeAdd", |
|
"ColorIncreasing", |
|
"ContrastIncreasing", |
|
"BrightnessIncreasing", |
|
"SharpnessIncreasing", |
|
"ShearX", |
|
"ShearY", |
|
"TranslateXRel", |
|
"TranslateYRel", |
|
|
|
] |
|
|
|
|
|
|
|
|
|
_RAND_CHOICE_WEIGHTS_0 = { |
|
"Rotate": 0.3, |
|
"ShearX": 0.2, |
|
"ShearY": 0.2, |
|
"TranslateXRel": 0.1, |
|
"TranslateYRel": 0.1, |
|
"Color": 0.025, |
|
"Sharpness": 0.025, |
|
"AutoContrast": 0.025, |
|
"Solarize": 0.005, |
|
"SolarizeAdd": 0.005, |
|
"Contrast": 0.005, |
|
"Brightness": 0.005, |
|
"Equalize": 0.005, |
|
"Posterize": 0, |
|
"Invert": 0, |
|
} |
|
|
|
|
|
def _select_rand_weights(weight_idx=0, transforms=None):
    """Return normalised selection probabilities for *transforms*.

    :param weight_idx: index of the weight set; only 0 is accepted (asserted).
    :param transforms: op names to look up; defaults to ``_RAND_TRANSFORMS``.
        Every name must be a key of ``_RAND_CHOICE_WEIGHTS_0``.
    :return: weights in the order of *transforms*, divided by their sum.
    """
    names = transforms or _RAND_TRANSFORMS
    assert weight_idx == 0
    table = _RAND_CHOICE_WEIGHTS_0
    raw = [table[name] for name in names]
    # Dividing a list by a NumPy scalar yields a NumPy array.
    return raw / np.sum(raw)
|
|
|
|
|
class AugmentOp:
    """A single named augmentation applied with a probability and magnitude.

    The op function and its level-to-argument mapping are looked up by *name*
    in ``NAME_TO_OP`` and ``LEVEL_TO_ARG``.
    """

    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
        """Configure the op.

        :param name: key into ``NAME_TO_OP`` / ``LEVEL_TO_ARG``.
        :param prob: probability of applying the op on a given call.
        :param magnitude: base magnitude handed to the level function.
        :param hparams: optional settings; ``img_mean``, ``interpolation`` and
            ``magnitude_std`` are read here. Falls back to ``_HPARAMS_DEFAULT``
            when empty or None.
        """
        if not hparams:
            hparams = _HPARAMS_DEFAULT
        self.name = name
        self.aug_fn = NAME_TO_OP[name]
        self.level_fn = LEVEL_TO_ARG[name]
        self.prob = prob
        self.magnitude = magnitude
        # Keep a private copy of the hyper-parameters.
        self.hparams = hparams.copy()
        # Keyword arguments forwarded to every op call.
        self.kwargs = {
            "fillcolor": hparams["img_mean"] if "img_mean" in hparams else _FILL,
            "resample": (
                hparams["interpolation"]
                if "interpolation" in hparams
                else _RANDOM_INTERPOLATION
            ),
        }
        # Std-dev of Gaussian noise added to the magnitude; 0 disables it.
        self.magnitude_std = self.hparams.get("magnitude_std", 0)

    def __call__(self, img):
        """Apply the op to *img* with probability ``self.prob``.

        Returns *img* unchanged when the op is skipped, otherwise the result
        of the op function.
        """
        skipped = self.prob < 1.0 and random.random() > self.prob
        if skipped:
            return img

        level = self.magnitude
        if self.magnitude_std and self.magnitude_std > 0:
            level = random.gauss(level, self.magnitude_std)
        # Clamp to the valid magnitude range.
        level = min(_MAX_LEVEL, max(0, level))

        if self.level_fn is None:
            op_args = tuple()
        else:
            op_args = self.level_fn(level, self.hparams)

        return self.aug_fn(img, *op_args, **self.kwargs)
|
|
|
|
|
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
    """Build one ``AugmentOp`` (prob 0.5) per transform name.

    :param magnitude: magnitude given to every op.
    :param hparams: hyper-parameters; ``_HPARAMS_DEFAULT`` when empty or None.
    :param transforms: op names; ``_RAND_TRANSFORMS`` when empty or None.
    :return: list of ``AugmentOp`` in the order of *transforms*.
    """
    if not hparams:
        hparams = _HPARAMS_DEFAULT
    if not transforms:
        transforms = _RAND_TRANSFORMS
    ops = []
    for op_name in transforms:
        ops.append(AugmentOp(op_name, prob=0.5, magnitude=magnitude, hparams=hparams))
    return ops
|
|
|
|
|
class RandAugment:
    """Apply a random subset of augmentation ops to an image.

    On each call ``num_layers`` distinct ops are drawn from ``ops`` (optionally
    weighted by ``choice_weights``) and applied one after another.
    """

    def __init__(self, ops, num_layers=2, choice_weights=None):
        """Store the candidate ops, how many to draw, and the draw weights."""
        self.ops = ops
        self.num_layers = num_layers
        self.choice_weights = choice_weights

    def __call__(self, img):
        """Return *img* after passing it through the drawn ops in order."""
        # Sampling is without replacement, so an op is used at most once per call.
        selected = np.random.choice(
            self.ops,
            self.num_layers,
            replace=False,
            p=self.choice_weights,
        )
        result = img
        for apply_op in selected:
            result = apply_op(result)
        return result
|
|
|
|
|
def rand_augment_transform(config_str, hparams):
    """
    Create a RandAugment transform

    :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
        dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The
        remaining sections, not order specific, determine
            'm' - integer magnitude of rand augment
            'n' - integer num layers (number of transform ops selected per image)
            'w' - integer probability weight index (index of a set of weights to influence choice of op)
            'mstd' - float std deviation of magnitude noise applied
            'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
        Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
        'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
        Sections that contain no digit are ignored.

    :param hparams: Other hparams (kwargs) for the RandAugmentation scheme. May be None. When an 'mstd' section is
        present, ``magnitude_std`` is written into this dict unless the key already exists (the caller's dict is
        modified).

    :return: A ``RandAugment`` instance, callable on an image
    """
    if hparams is None:
        # Downstream helpers accept empty hparams; an 'mstd' section needs a dict to write into.
        hparams = {}

    magnitude = _MAX_LEVEL
    num_layers = 2
    weight_idx = None
    transforms = _RAND_TRANSFORMS

    config = config_str.split("-")
    assert config[0] == "rand"

    for section in config[1:]:
        # "mstd0.5" -> ["mstd", "0.5", ""]; a section without any digit yields a single item.
        parts = re.split(r"(\d.*)", section)
        if len(parts) < 2:
            continue
        key, val = parts[:2]
        if key == "mstd":
            # A magnitude_std already present in hparams takes precedence over the config string.
            hparams.setdefault("magnitude_std", float(val))
        elif key == "inc":
            # val is a string: bool("0") would be True, so parse it as an integer flag first.
            if int(val):
                transforms = _RAND_INCREASING_TRANSFORMS
        elif key == "m":
            magnitude = int(val)
        elif key == "n":
            num_layers = int(val)
        elif key == "w":
            weight_idx = int(val)
        else:
            assert False, "Unknown RandAugment config section"

    ra_ops = rand_augment_ops(
        magnitude=magnitude, hparams=hparams, transforms=transforms
    )
    # Weights are looked up for the default transform names; they line up with ra_ops by position.
    choice_weights = (
        None if weight_idx is None else _select_rand_weights(weight_idx)
    )
    return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
|
|