import albumentations as A
from torch.utils.data import Dataset, DataLoader
import pycocotools.mask as maskUtils
from pycocotools.coco import COCO
import random
import os.path as osp
import cv2
import numpy as np
from scipy.ndimage import distance_transform_edt
def is_grey(img: np.ndarray) -> bool:
    # Treat anything that is not a 3-channel image as greyscale.
    return not (len(img.shape) == 3 and img.shape[2] == 3)
def square_pad_resize(img: np.ndarray, tgt_size: int, pad_value=(0, 0, 0)):
    h, w = img.shape[:2]
    pad_h, pad_w = 0, 0

    # Pad the shorter side so the image becomes square.
    if w < h:
        pad_w = h - w
        w += pad_w
    elif h < w:
        pad_h = w - h
        h += pad_h

    # If the square is still smaller than the target size, pad up to it.
    pad_size = tgt_size - h
    if pad_size > 0:
        pad_h += pad_size
        pad_w += pad_size

    if pad_h > 0 or pad_w > 0:
        if is_grey(img):
            if isinstance(pad_value, tuple):
                pad_value = pad_value[0]
        elif isinstance(pad_value, int):
            pad_value = (pad_value, pad_value, pad_value)
        img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=pad_value)

    # Resize the padded square to the target size (downscale with INTER_AREA,
    # upscale with INTER_LINEAR).
    resize_ratio = tgt_size / img.shape[0]
    if resize_ratio < 1:
        img = cv2.resize(img, (tgt_size, tgt_size), interpolation=cv2.INTER_AREA)
    elif resize_ratio > 1:
        img = cv2.resize(img, (tgt_size, tgt_size), interpolation=cv2.INTER_LINEAR)

    return img, resize_ratio, pad_h, pad_w
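
# Worked example (illustrative values, not from the original source): with
# tgt_size=640, a 300x500 BGR image is padded on the bottom/right to 500x500,
# then padded again to 640x640 (pad_h=340, pad_w=140 in total), and
# resize_ratio stays 1.0 so no resampling happens. A 1000x800 image is padded
# to 1000x1000 and then downscaled to 640x640 with INTER_AREA (resize_ratio=0.64).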
class MaskRefineDataset(Dataset):
    """COCO-style dataset for mask refinement: yields an image (optionally
    concatenated with a coarse instance mask as an extra channel) together
    with the ground-truth mask and an optional per-pixel distance weight map."""

    def __init__(self,
                 refine_ann_path: str,
                 data_root: str,
                 load_instance_mask: bool = True,
                 aug_ins_prob: float = 0.,
                 ins_rect_prob: float = 0.,
                 output_size: int = 720,
                 augmentation: bool = False,
                 with_distance: bool = False):
        self.load_instance_mask = load_instance_mask
        self.ann_util = COCO(refine_ann_path)
        self.img_ids = self.ann_util.getImgIds()
        self.set_load_method(load_instance_mask)
        self.data_root = data_root
        self.ins_rect_prob = ins_rect_prob
        self.aug_ins_prob = aug_ins_prob
        self.augmentation = augmentation
        if augmentation:
            transform = [
                A.OpticalDistortion(),
                A.HorizontalFlip(),
                A.CLAHE(),
                A.Posterize(),
                A.CropAndPad(percent=0.1, p=0.3, pad_mode=cv2.BORDER_CONSTANT, pad_cval=0, pad_cval_mask=0, keep_size=True),
                A.RandomContrast(),
                A.Rotate(30, p=0.3, mask_value=0, border_mode=cv2.BORDER_CONSTANT)
            ]
            self._aug_transform = A.Compose(transform)
        else:
            self._aug_transform = None
        self.output_size = output_size
        self.with_distance = with_distance
    def set_output_size(self, size: int):
        self.output_size = size

    def set_load_method(self, load_instance_mask: bool):
        if load_instance_mask:
            self._load_mask = self._load_with_instance
        else:
            self._load_mask = self._load_without_instance
    def __getitem__(self, idx: int):
        img_id = self.img_ids[idx]
        img_meta = self.ann_util.imgs[img_id]
        img_path = osp.join(self.data_root, img_meta['file_name'])
        img = cv2.imread(img_path)

        # Pick one annotation of this image at random; images without
        # annotations fall back to an all-zero mask.
        annids = self.ann_util.getAnnIds([img_id])
        if len(annids) > 0:
            ann = random.choice(annids)
            ann = self.ann_util.anns[ann]
            assert ann['image_id'] == img_id
        else:
            ann = None
        return self._load_mask(img, ann)
    def transform(self, img: np.ndarray, mask: np.ndarray, ins_seg: np.ndarray = None) -> dict:
        use_seg = ins_seg is not None

        if self.augmentation:
            masks = [mask]
            if use_seg:
                masks.append(ins_seg)
            data = self._aug_transform(image=img, masks=masks)
            img = data['image']
            masks = data['masks']
            mask = masks[0]
            if use_seg:
                ins_seg = masks[1]

        # Pad/resize everything to a square of output_size; the image is padded
        # with a random grey value, the masks with zeros.
        img = square_pad_resize(img, self.output_size, random.randint(0, 255))[0]
        mask = square_pad_resize(mask, self.output_size, 0)[0]
        if use_seg:
            ins_seg = square_pad_resize(ins_seg, self.output_size, 0)[0]

        # HWC uint8 -> CHW float32 in [0, 1]; the coarse instance mask is
        # appended as an extra input channel.
        img = (img.astype(np.float32) / 255.).transpose((2, 0, 1))
        mask = mask[None, ...]
        if use_seg:
            ins_seg = ins_seg[None, ...]
            img = np.concatenate((img, ins_seg), axis=0)

        data = {'img': img, 'mask': mask}
        if self.with_distance:
            # Per-pixel weight map from the Euclidean distance transform of the
            # ground-truth mask (see the note below this method).
            dist = distance_transform_edt(mask[0])
            dist_max = dist.max()
            if dist_max != 0:
                dist = 1 - dist / dist_max
                # diff_mat = cv2.bitwise_xor(mask[0], ins_seg[0])
                # dist = dist + diff_mat + 0.2
                dist = dist + 0.2
                dist = dist.size / (dist.sum() + 1) * dist
                dist = np.clip(dist, 0, 20)
            else:
                dist = np.ones_like(dist)
            data['dist_weight'] = dist[None, ...]
        return data
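
    # Note on dist_weight (a reading of the code above, values illustrative):
    # distance_transform_edt gives each foreground pixel its distance to the
    # nearest background pixel, so after "1 - dist / dist_max" plus 0.2 the
    # background and mask-boundary pixels sit near 1.2 while the deep mask
    # interior sits near 0.2. The "dist.size / (dist.sum() + 1)" factor
    # rescales the map so it averages to roughly 1 over the image, and the
    # clip caps any single pixel's weight at 20.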
    def _load_with_instance(self, img: np.ndarray, ann: dict):
        if ann is None:
            mask = np.zeros(img.shape[:2], dtype=np.float32)
            ins_seg = mask
        else:
            mask = maskUtils.decode(ann['segmentation']).astype(np.float32)
            if self.augmentation and random.random() < self.ins_rect_prob:
                # Use the annotation's bounding box as the coarse instance mask.
                ins_seg = np.zeros_like(mask)
                bbox = [int(b) for b in ann['bbox']]
                ins_seg[bbox[1]: bbox[1] + bbox[3], bbox[0]: bbox[0] + bbox[2]] = 1
            elif len(ann['pred_segmentations']) > 0:
                # Use one of the stored model predictions as the coarse mask.
                ins_seg = random.choice(ann['pred_segmentations'])
                ins_seg = maskUtils.decode(ins_seg).astype(np.float32)
            else:
                ins_seg = mask
            if self.augmentation and random.random() < self.aug_ins_prob:
                # Randomly dilate or erode the coarse mask to simulate
                # over-/under-segmentation.
                ksize = random.choice([1, 3, 5, 7])
                ksize = ksize * 2 + 1
                kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, ksize=(ksize, ksize))
                if random.random() < 0.5:
                    ins_seg = cv2.dilate(ins_seg, kernel)
                else:
                    ins_seg = cv2.erode(ins_seg, kernel)
        return self.transform(img, mask, ins_seg)
    def _load_without_instance(self, img: np.ndarray, ann: dict):
        if ann is None:
            mask = np.zeros(img.shape[:2], dtype=np.float32)
        else:
            mask = maskUtils.decode(ann['segmentation']).astype(np.float32)
        return self.transform(img, mask)

    def __len__(self):
        return len(self.img_ids)
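
# Hypothetical consumer of 'dist_weight' (an assumption for illustration only;
# the training loss is not part of this file): the weight map averages to ~1
# and peaks near mask boundaries, so one natural use is as a per-pixel weight
# on the refinement loss.
def example_weighted_mask_loss(pred_logits, target_mask, dist_weight):
    # Sketch: per-pixel BCE scaled by the dataset's dist_weight map.
    import torch.nn.functional as F  # torch is already a dependency via DataLoader
    bce = F.binary_cross_entropy_with_logits(pred_logits, target_mask, reduction='none')
    return (bce * dist_weight).mean()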
if __name__ == '__main__':
    ann_path = r'workspace/test_syndata/annotations/refine_train.json'
    data_root = r'workspace/test_syndata/train'
    aug_ins_prob = 0.5
    load_instance_mask = True
    ins_rect_prob = 0.25
    output_size = 640
    augmentation = True

    random.seed(0)
    md = MaskRefineDataset(ann_path, data_root, load_instance_mask, aug_ins_prob, ins_rect_prob, output_size, augmentation, with_distance=True)
    dl = DataLoader(md, batch_size=1, shuffle=False, persistent_workers=True,
                    num_workers=1, pin_memory=True)

    # Visual sanity check: show the image, ground-truth mask, coarse instance
    # channel, and distance weight map for each sample.
    for data in dl:
        img = data['img'].cpu().numpy()
        img = (img[0, :3].transpose((1, 2, 0)) * 255).astype(np.uint8)
        mask = (data['mask'].cpu().numpy()[0][0] * 255).astype(np.uint8)
        if load_instance_mask:
            ins = (data['img'].cpu().numpy()[0][3] * 255).astype(np.uint8)
            cv2.imshow('ins', ins)
        dist = data['dist_weight'].cpu().numpy()[0][0]
        dist = (dist / dist.max() * 255).astype(np.uint8)
        cv2.imshow('img', img)
        cv2.imshow('mask', mask)
        cv2.imshow('dist_weight', dist)
        cv2.waitKey(0)