"""Mixup detection dataset wrapper.""" from __future__ import absolute_import import numpy as np import torch import torch.utils.data as data class MixupDetection(data.Dataset): """Detection dataset wrapper that performs mixup for normal dataset. Parameters ---------- dataset : mx.gluon.data.Dataset Gluon dataset object. mixup : callable random generator, e.g. np.random.uniform A random mixup ratio sampler, preferably a random generator from numpy.random A random float will be sampled each time with mixup(*args). Use None to disable. *args : list Additional arguments for mixup random sampler. """ def __init__(self, dataset, mixup=None, preproc=None, *args): super().__init__(dataset.input_dim) self._dataset = dataset self.preproc = preproc self._mixup = mixup self._mixup_args = args def set_mixup(self, mixup=None, *args): """Set mixup random sampler, use None to disable. Parameters ---------- mixup : callable random generator, e.g. np.random.uniform A random mixup ratio sampler, preferably a random generator from numpy.random A random float will be sampled each time with mixup(*args) *args : list Additional arguments for mixup random sampler. """ self._mixup = mixup self._mixup_args = args def __len__(self): return len(self._dataset) @Dataset.resize_getitem def __getitem__(self, idx): self._dataset._input_dim = self.input_dim # first image img1, label1, _, _ = self._dataset.pull_item(idx) lambd = 1 # draw a random lambda ratio from distribution if self._mixup is not None: lambd = max(0, min(1, self._mixup(*self._mixup_args))) if lambd >= 1: weights1 = np.ones((label1.shape[0], 1)) label1 = np.hstack((label1, weights1)) height, width, _ = img1.shape img_info = (width, height) if self.preproc is not None: img_o, target_o = self.preproc(img1, label1, self.input_dim) return img_o, target_o, img_info, idx # second image idx2 = int(np.random.choice(np.delete(np.arange(len(self)), idx))) img2, label2, _, _ = self._dataset.pull_item(idx2) # mixup two images height = max(img1.shape[0], img2.shape[0]) width = max(img1.shape[1], img2.shape[1]) mix_img = np.zeros((height, width, 3), dtype=np.float32) mix_img[: img1.shape[0], : img1.shape[1], :] = img1.astype(np.float32) * lambd mix_img[: img2.shape[0], : img2.shape[1], :] += img2.astype(np.float32) * (1.0 - lambd) mix_img = mix_img.astype(np.uint8) y1 = np.hstack((label1, np.full((label1.shape[0], 1), lambd))) y2 = np.hstack((label2, np.full((label2.shape[0], 1), 1.0 - lambd))) mix_label = np.vstack((y1, y2)) if self.preproc is not None: mix_img, padded_labels = self.preproc(mix_img, mix_label, self.input_dim) img_info = (width, height) return mix_img, padded_labels, img_info, idx def pull_item(self, idx): self._dataset._input_dim = self.input_dim # first image img1, label1, _, _ = self._dataset.pull_item(idx) lambd = 1 # draw a random lambda ratio from distribution if self._mixup is not None: lambd = max(0, min(1, self._mixup(*self._mixup_args))) if lambd >= 1: weights1 = np.ones((label1.shape[0], 1)) label1 = np.hstack((label1, weights1)) height, width, _ = img1.shape img_info = (width, height) if self.preproc is not None: img_o, target_o = self.preproc(img1, label1, self.input_dim) return img_o, target_o, img_info, idx # second image idx2 = int(np.random.choice(np.delete(np.arange(len(self)), idx))) img2, label2 = self._dataset.pull_item(idx2) # mixup two images height = max(img1.shape[0], img2.shape[0]) width = max(img1.shape[1], img2.shape[1]) mix_img = np.zeros((height, width, 3), dtype=np.float32) mix_img[: img1.shape[0], : img1.shape[1], :] = img1.astype(np.float32) * lambd mix_img[: img2.shape[0], : img2.shape[1], :] += img2.astype(np.float32) * (1.0 - lambd) mix_img = mix_img.astype(np.uint8) y1 = np.hstack((label1, np.full((label1.shape[0], 1), lambd))) y2 = np.hstack((label2, np.full((label2.shape[0], 1), 1.0 - lambd))) mix_label = np.vstack((y1, y2)) if self.preproc is not None: mix_img, padded_labels = self.preproc(mix_img, mix_label, self.input_dim) img_info = (width, height) return mix_img, padded_labels, img_info, idx