# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import random

import cv2
import numpy as np
from fvcore.transforms.transform import Transform


class ColorAugSSDTransform(Transform):
    """
    Photometric (color) augmentation used by the SSD detector.

    Randomly perturbs brightness, contrast, saturation and hue of an image.
    Geometry is never touched, so coordinates and segmentation maps pass
    through unchanged.

    Reference:
        Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy,
        Scott Reed, Cheng-Yang Fu, Alexander C. Berg.
        SSD: Single Shot MultiBox Detector. ECCV 2016.

    Ported from:
        https://github.com/weiliu89/caffe/blob
          /4817bf8b4200b35ada8ed0dc378dceaf38c539e4
          /src/caffe/util/im_transforms.cpp

        https://github.com/chainer/chainercv/blob
          /7159616642e0be7c5b3ef380b848e16b7e99355b/chainercv
          /links/model/ssd/transforms.py

    Note: every random decision is drawn from the global :mod:`random`
    module at the moment :meth:`apply_image` runs, so one instance generally
    yields a different result on each call.
    """

    def __init__(
        self,
        img_format,
        brightness_delta=32,
        contrast_low=0.5,
        contrast_high=1.5,
        saturation_low=0.5,
        saturation_high=1.5,
        hue_delta=18,
    ):
        """
        Args:
            img_format (str): channel order of incoming images, "BGR" or
                "RGB". Only the derived flag ``is_rgb`` is kept.
            brightness_delta (float): largest absolute additive brightness
                offset.
            contrast_low, contrast_high (float): bounds of the multiplicative
                contrast factor.
            saturation_low, saturation_high (float): bounds of the factor
                applied to the HSV saturation channel.
            hue_delta (int): largest absolute shift added to the HSV hue
                channel (shifted values are wrapped modulo 180).
        """
        super().__init__()
        assert img_format in ["BGR", "RGB"]
        self.is_rgb = img_format == "RGB"
        # Explicit equivalent of storing every constructor argument except
        # ``img_format`` (which is reduced to ``is_rgb`` above).
        self.brightness_delta = brightness_delta
        self.contrast_low = contrast_low
        self.contrast_high = contrast_high
        self.saturation_low = saturation_low
        self.saturation_high = saturation_high
        self.hue_delta = hue_delta

    def apply_coords(self, coords):
        """Color jitter does not move points; ``coords`` is returned as-is."""
        return coords

    def apply_segmentation(self, segmentation):
        """Color jitter does not change labels; ``segmentation`` is returned as-is."""
        return segmentation

    def apply_image(self, img, interp=None):
        """
        Apply the randomized color jitter to ``img``.

        Brightness is always tried first. Contrast is then tried either before
        or after the saturation -> hue pair, chosen with probability 1/2.
        Each individual step fires with probability 1/2.

        Args:
            img (ndarray): image in the channel order given by ``img_format``;
                expected to be HxWx3 since the channel swap and the BGR<->HSV
                conversions operate on three channels. Presumably uint8, as
                the hue wrap uses OpenCV's 8-bit hue range — TODO confirm.
            interp: unused; accepted for interface compatibility.

        Returns:
            ndarray: the jittered image, in the same channel order as ``img``.
        """
        if self.is_rgb:
            # The HSV steps convert from BGR, so reorder RGB -> BGR first.
            img = img[:, :, [2, 1, 0]]

        img = self.brightness(img)

        contrast_first = random.randrange(2)
        if contrast_first:
            pipeline = (self.contrast, self.saturation, self.hue)
        else:
            pipeline = (self.saturation, self.hue, self.contrast)
        for step in pipeline:
            img = step(img)

        if self.is_rgb:
            # Undo the earlier swap so the caller gets RGB back.
            img = img[:, :, [2, 1, 0]]
        return img

    def convert(self, img, alpha=1, beta=0):
        """
        Return ``img * alpha + beta`` computed in float32, clipped to
        [0, 255] and cast to uint8.
        """
        scaled = img.astype(np.float32) * alpha + beta
        return np.clip(scaled, 0, 255).astype(np.uint8)

    def brightness(self, img):
        """
        With probability 1/2, add a uniform offset drawn from
        [-brightness_delta, brightness_delta]; otherwise return ``img``.
        """
        if not random.randrange(2):
            return img
        offset = random.uniform(-self.brightness_delta, self.brightness_delta)
        return self.convert(img, beta=offset)

    def contrast(self, img):
        """
        With probability 1/2, scale ``img`` by a uniform factor drawn from
        [contrast_low, contrast_high]; otherwise return ``img``.
        """
        if not random.randrange(2):
            return img
        factor = random.uniform(self.contrast_low, self.contrast_high)
        return self.convert(img, alpha=factor)

    def saturation(self, img):
        """
        With probability 1/2, scale the HSV saturation channel by a uniform
        factor drawn from [saturation_low, saturation_high]; otherwise return
        ``img``. Input and output are BGR.
        """
        if not random.randrange(2):
            return img
        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        factor = random.uniform(self.saturation_low, self.saturation_high)
        hsv[:, :, 1] = self.convert(hsv[:, :, 1], alpha=factor)
        return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)

    def hue(self, img):
        """
        With probability 1/2, rotate the HSV hue channel by an integer drawn
        from [-hue_delta, hue_delta]; otherwise return ``img``.
        Input and output are BGR.
        """
        if not random.randrange(2):
            return img
        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        shift = random.randint(-self.hue_delta, self.hue_delta)
        # Widen to int before adding so negative shifts don't wrap in uint8;
        # modulo 180 matches OpenCV's 8-bit hue range [0, 180).
        hsv[:, :, 0] = (hsv[:, :, 0].astype(int) + shift) % 180
        return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)