Spaces:
Sleeping
Sleeping
| import albumentations as A | |
| from torch.utils.data import Dataset, DataLoader | |
| import pycocotools.mask as maskUtils | |
| from pycocotools.coco import COCO | |
| import random | |
| import os.path as osp | |
| import cv2 | |
| import numpy as np | |
| from scipy.ndimage import distance_transform_bf, distance_transform_edt, distance_transform_cdt | |
| def is_grey(img: np.ndarray): | |
| if len(img.shape) == 3 and img.shape[2] == 3: | |
| return False | |
| else: | |
| return True | |
| def square_pad_resize(img: np.ndarray, tgt_size: int, pad_value = (0, 0, 0)): | |
| h, w = img.shape[:2] | |
| pad_h, pad_w = 0, 0 | |
| # make square image | |
| if w < h: | |
| pad_w = h - w | |
| w += pad_w | |
| elif h < w: | |
| pad_h = w - h | |
| h += pad_h | |
| pad_size = tgt_size - h | |
| if pad_size > 0: | |
| pad_h += pad_size | |
| pad_w += pad_size | |
| if pad_h > 0 or pad_w > 0: | |
| c = 1 | |
| if is_grey(img): | |
| if isinstance(pad_value, tuple): | |
| pad_value = pad_value[0] | |
| else: | |
| if isinstance(pad_value, int): | |
| pad_value = (pad_value, pad_value, pad_value) | |
| img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=pad_value) | |
| resize_ratio = tgt_size / img.shape[0] | |
| if resize_ratio < 1: | |
| img = cv2.resize(img, (tgt_size, tgt_size), interpolation=cv2.INTER_AREA) | |
| elif resize_ratio > 1: | |
| img = cv2.resize(img, (tgt_size, tgt_size), interpolation=cv2.INTER_LINEAR) | |
| return img, resize_ratio, pad_h, pad_w | |
| class MaskRefineDataset(Dataset): | |
| def __init__(self, | |
| refine_ann_path: str, | |
| data_root: str, | |
| load_instance_mask: bool = True, | |
| aug_ins_prob: float = 0., | |
| ins_rect_prob: float = 0., | |
| output_size: int = 720, | |
| augmentation: bool = False, | |
| with_distance: bool = False): | |
| self.load_instance_mask = load_instance_mask | |
| self.ann_util = COCO(refine_ann_path) | |
| self.img_ids = self.ann_util.getImgIds() | |
| self.set_load_method(load_instance_mask) | |
| self.data_root = data_root | |
| self.ins_rect_prob = ins_rect_prob | |
| self.aug_ins_prob = aug_ins_prob | |
| self.augmentation = augmentation | |
| if augmentation: | |
| transform = [ | |
| A.OpticalDistortion(), | |
| A.HorizontalFlip(), | |
| A.CLAHE(), | |
| A.Posterize(), | |
| A.CropAndPad(percent=0.1, p=0.3, pad_mode=cv2.BORDER_CONSTANT, pad_cval=0, pad_cval_mask=0, keep_size=True), | |
| A.RandomContrast(), | |
| A.Rotate(30, p=0.3, mask_value=0, border_mode=cv2.BORDER_CONSTANT) | |
| ] | |
| self._aug_transform = A.Compose(transform) | |
| else: | |
| self._aug_transform = None | |
| self.output_size = output_size | |
| self.with_distance = with_distance | |
| def set_output_size(self, size: int): | |
| self.output_size = size | |
| def set_load_method(self, load_instance_mask: bool): | |
| if load_instance_mask: | |
| self._load_mask = self._load_with_instance | |
| else: | |
| self._load_mask = self._load_without_instance | |
| def __getitem__(self, idx: int): | |
| img_id = self.img_ids[idx] | |
| img_meta = self.ann_util.imgs[img_id] | |
| img_path = osp.join(self.data_root, img_meta['file_name']) | |
| img = cv2.imread(img_path) | |
| annids = self.ann_util.getAnnIds([img_id]) | |
| if len(annids) > 0: | |
| ann = random.choice(annids) | |
| ann = self.ann_util.anns[ann] | |
| assert ann['image_id'] == img_id | |
| else: | |
| ann = None | |
| return self._load_mask(img, ann) | |
| def transform(self, img: np.ndarray, mask: np.ndarray, ins_seg: np.ndarray = None) -> dict: | |
| if ins_seg is not None: | |
| use_seg = True | |
| else: | |
| use_seg = False | |
| if self.augmentation: | |
| masks = [mask] | |
| if use_seg: | |
| masks.append(ins_seg) | |
| data = self._aug_transform(image=img, masks=masks) | |
| img = data['image'] | |
| masks = data['masks'] | |
| mask = masks[0] | |
| if use_seg: | |
| ins_seg = masks[1] | |
| img = square_pad_resize(img, self.output_size, random.randint(0, 255))[0] | |
| mask = square_pad_resize(mask, self.output_size, 0)[0] | |
| if ins_seg is not None: | |
| ins_seg = square_pad_resize(ins_seg, self.output_size, 0)[0] | |
| img = (img.astype(np.float32) / 255.).transpose((2, 0, 1)) | |
| mask = mask[None, ...] | |
| if use_seg: | |
| ins_seg = ins_seg[None, ...] | |
| img = np.concatenate((img, ins_seg), axis=0) | |
| data = {'img': img, 'mask': mask} | |
| if self.with_distance: | |
| dist = distance_transform_edt(mask[0]) | |
| dist_max = dist.max() | |
| if dist_max != 0: | |
| dist = 1 - dist / dist_max | |
| # diff_mat = cv2.bitwise_xor(mask[0], ins_seg[0]) | |
| # dist = dist + diff_mat + 0.2 | |
| dist = dist + 0.2 | |
| dist = dist.size / (dist.sum() + 1) * dist | |
| dist = np.clip(dist, 0, 20) | |
| else: | |
| dist = np.ones_like(dist) | |
| # print(dist.max(), dist.min()) | |
| data['dist_weight'] = dist[None, ...] | |
| return data | |
| def _load_with_instance(self, img: np.ndarray, ann: dict): | |
| if ann is None: | |
| mask = np.zeros(img.shape[:2], dtype=np.float32) | |
| ins_seg = mask | |
| else: | |
| mask = maskUtils.decode(ann['segmentation']).astype(np.float32) | |
| if self.augmentation and random.random() < self.ins_rect_prob: | |
| ins_seg = np.zeros_like(mask) | |
| bbox = [int(b) for b in ann['bbox']] | |
| ins_seg[bbox[1]: bbox[1] + bbox[3], bbox[0]: bbox[0] + bbox[2]] = 1 | |
| elif len(ann['pred_segmentations']) > 0: | |
| ins_seg = random.choice(ann['pred_segmentations']) | |
| ins_seg = maskUtils.decode(ins_seg).astype(np.float32) | |
| else: | |
| ins_seg = mask | |
| if self.augmentation and random.random() < self.aug_ins_prob: | |
| ksize = random.choice([1, 3, 5, 7]) | |
| ksize = ksize * 2 + 1 | |
| kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, ksize=(ksize, ksize)) | |
| if random.random() < 0.5: | |
| ins_seg = cv2.dilate(ins_seg, kernel) | |
| else: | |
| ins_seg = cv2.erode(ins_seg, kernel) | |
| return self.transform(img, mask, ins_seg) | |
| def _load_without_instance(self, img: np.ndarray, ann: dict): | |
| if ann is None: | |
| mask = np.zeros(img.shape[:2], dtype=np.float32) | |
| else: | |
| mask = maskUtils.decode(ann['segmentation']).astype(np.float32) | |
| return self.transform(img, mask) | |
| def __len__(self): | |
| return len(self.img_ids) | |
| if __name__ == '__main__': | |
| ann_path = r'workspace/test_syndata/annotations/refine_train.json' | |
| data_root = r'workspace/test_syndata/train' | |
| ann_path = r'workspace/test_syndata/annotations/refine_train.json' | |
| data_root = r'workspace/test_syndata/train' | |
| aug_ins_prob = 0.5 | |
| load_instance_mask = True | |
| ins_rect_prob = 0.25 | |
| output_size = 640 | |
| augmentation = True | |
| random.seed(0) | |
| md = MaskRefineDataset(ann_path, data_root, load_instance_mask, aug_ins_prob, ins_rect_prob, output_size, augmentation, with_distance=True) | |
| dl = DataLoader(md, batch_size=1, shuffle=False, persistent_workers=True, | |
| num_workers=1, pin_memory=True) | |
| for data in dl: | |
| img = data['img'].cpu().numpy() | |
| img = (img[0, :3].transpose((1, 2, 0)) * 255).astype(np.uint8) | |
| mask = (data['mask'].cpu().numpy()[0][0] * 255).astype(np.uint8) | |
| if load_instance_mask: | |
| ins = (data['img'].cpu().numpy()[0][3] * 255).astype(np.uint8) | |
| cv2.imshow('ins', ins) | |
| dist = data['dist_weight'].cpu().numpy()[0][0] | |
| dist = (dist / dist.max() * 255).astype(np.uint8) | |
| cv2.imshow('img', img) | |
| cv2.imshow('mask', mask) | |
| cv2.imshow('dist_weight', dist) | |
| cv2.waitKey(0) | |
| # cv2.imwrite('') |