# Scraped page banner, not part of the module: "Spaces: Runtime error".
import numpy as np

from ..builder import PIPELINES
# NOTE(review): ``PIPELINES`` is imported by this module but is not referenced
# in the visible source -- confirm whether a registry decorator belongs here.
class InstaBoost(object):
    r"""Data augmentation method in `InstaBoost: Boosting Instance
    Segmentation Via Probability Map Guided Copy-Pasting
    <https://arxiv.org/abs/1908.07801>`_.

    Refer to https://github.com/GothicAi/Instaboost for implementation details.

    The transform reads ``results['img']`` and ``results['ann_info']``
    (keys ``labels``, ``bboxes``, ``masks``), optionally passes them through
    ``instaboostfast.get_new_data`` and writes the resulting image and
    annotations back into ``results``.

    Args:
        action_candidate (tuple): Forwarded unchanged to
            ``instaboostfast.InstaBoostConfig``.
        action_prob (tuple): Forwarded unchanged to ``InstaBoostConfig``.
        scale (tuple): Forwarded unchanged to ``InstaBoostConfig``.
        dx (int): Forwarded unchanged to ``InstaBoostConfig``.
        dy (int): Forwarded unchanged to ``InstaBoostConfig``.
        theta (tuple): Forwarded unchanged to ``InstaBoostConfig``.
        color_prob (float): Forwarded unchanged to ``InstaBoostConfig``.
        hflag (bool): Forwarded unchanged to ``InstaBoostConfig``.
        aug_ratio (float): Probability that a call actually runs the
            InstaBoost augmentation; it is used as the weight of the
            "augment" outcome in ``np.random.choice``.

    Raises:
        ImportError: If the ``instaboostfast`` package is not installed.

    See the instaboostfast documentation for the meaning of the
    ``InstaBoostConfig`` arguments; this class does not interpret them.
    """

    def __init__(self,
                 action_candidate=('normal', 'horizontal', 'skip'),
                 action_prob=(1, 0, 0),
                 scale=(0.8, 1.2),
                 dx=15,
                 dy=15,
                 theta=(-1, 1),
                 color_prob=0.5,
                 hflag=False,
                 aug_ratio=0.5):
        try:
            import instaboostfast as instaboost
        except ImportError as err:
            raise ImportError(
                'Please run "pip install instaboostfast" '
                'to install instaboostfast first for instaboost augmentation.'
            ) from err
        self.cfg = instaboost.InstaBoostConfig(action_candidate, action_prob,
                                               scale, dx, dy, theta,
                                               color_prob, hflag)
        self.aug_ratio = aug_ratio

    def _load_anns(self, results):
        """Convert ``results['ann_info']`` into a list of annotation dicts.

        Each box is converted from ``[x1, y1, x2, y2]`` to
        ``[x1, y1, width, height]``.

        Args:
            results (dict): Must contain ``results['ann_info']`` with the
                keys ``labels``, ``masks`` and ``bboxes``.

        Returns:
            list[dict]: One dict per label with the keys ``category_id``,
            ``segmentation`` and ``bbox``.
        """
        ann_info = results['ann_info']
        labels = ann_info['labels']
        masks = ann_info['masks']
        bboxes = ann_info['bboxes']

        anns = []
        # Index by label position so a shorter ``bboxes``/``masks`` sequence
        # still raises instead of being silently truncated.
        for idx, label in enumerate(labels):
            x1, y1, x2, y2 = bboxes[idx]
            anns.append({
                'category_id': label,
                'segmentation': masks[idx],
                'bbox': [x1, y1, x2 - x1, y2 - y1]
            })
        return anns

    def _parse_anns(self, results, anns, img):
        """Write annotation dicts and the image back into ``results``.

        Boxes are converted from ``[x1, y1, width, height]`` back to
        ``[x1, y1, x2, y2]``. Annotations whose width or height is not
        positive are dropped together with their label and mask.

        Args:
            results (dict): Updated in place; must contain ``ann_info``.
            anns (list[dict]): Dicts with the keys ``bbox``, ``category_id``
                and ``segmentation``.
            img: Image stored under ``results['img']``.

        Returns:
            dict: The same ``results`` object, where
            ``ann_info['bboxes']`` is a float32 array of shape ``(N, 4)``,
            ``ann_info['labels']`` is an int64 array of shape ``(N,)`` and
            ``ann_info['masks']`` is a list of length ``N``.
        """
        gt_bboxes = []
        gt_labels = []
        gt_masks_ann = []
        for ann in anns:
            x1, y1, w, h = ann['bbox']
            # TODO: more essential bug need to be fixed in instaboost
            if w <= 0 or h <= 0:
                continue
            gt_bboxes.append([x1, y1, x1 + w, y1 + h])
            gt_labels.append(ann['category_id'])
            gt_masks_ann.append(ann['segmentation'])

        # ``reshape(-1, 4)`` keeps the box array two-dimensional even when no
        # annotation survives; a bare ``np.array([])`` would have shape (0,).
        gt_bboxes = np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4)
        gt_labels = np.array(gt_labels, dtype=np.int64)

        results['ann_info']['labels'] = gt_labels
        results['ann_info']['bboxes'] = gt_bboxes
        results['ann_info']['masks'] = gt_masks_ann
        results['img'] = img
        return results

    def __call__(self, results):
        """Apply InstaBoost to ``results`` with probability ``aug_ratio``.

        The annotations are always converted with ``_load_anns`` and written
        back with ``_parse_anns``, so degenerate boxes are filtered and the
        arrays are re-cast even when the augmentation itself is skipped.

        Args:
            results (dict): Must contain ``img`` (an array with a ``dtype``)
                and ``ann_info``.

        Returns:
            dict: The updated ``results``.

        Raises:
            ImportError: If augmentation is selected and ``instaboostfast``
                cannot be imported.
        """
        img = results['img']
        orig_type = img.dtype
        anns = self._load_anns(results)
        if np.random.choice([0, 1], p=[1 - self.aug_ratio, self.aug_ratio]):
            # Imported again because the name bound in ``__init__`` is local
            # to that method.
            try:
                import instaboostfast as instaboost
            except ImportError as err:
                raise ImportError('Please run "pip install instaboostfast" '
                                  'to install instaboostfast first.') from err
            # The image is handed to instaboost as uint8 and cast back to its
            # original dtype below.
            anns, img = instaboost.get_new_data(
                anns, img.astype(np.uint8), self.cfg, background=None)
        results = self._parse_anns(results, anns, img.astype(orig_type))
        return results

    def __repr__(self):
        """Return the class name followed by ``cfg`` and ``aug_ratio``."""
        repr_str = self.__class__.__name__
        repr_str += f'(cfg={self.cfg}, aug_ratio={self.aug_ratio})'
        return repr_str