Spaces:
Running
Running
| import numpy as np | |
| from typing import List, Union, Tuple, Dict | |
| import random | |
| from PIL import Image | |
| import cv2 | |
| import os.path as osp | |
| from tqdm import tqdm | |
| from panopticapi.utils import rgb2id, id2rgb | |
| from time import time | |
| import traceback | |
| from utils.io_utils import bbox_overlap_area | |
| from utils.logger import LOGGER | |
| from utils.constants import COLOR_PALETTE | |
class PartitionTree:
    """Recursive rectangle partition used to pick free regions for pasting.

    Every node stores a bounding box ``(bleft, btop, bright, bbottom)``.
    A leaf represents a candidate region; calling :meth:`new_partition` on a
    node turns it into an inner node with four (possibly overlapping)
    children covering the strips left of, above, right of and below the
    newly occupied rectangle.
    """

    def __init__(self, bleft: int, btop: int, bright: int, bbottom: int, parent=None) -> None:
        """Create a leaf node covering the given box.

        Args:
            bleft: Left edge of the region.
            btop: Top edge of the region.
            bright: Right edge; clamped up to ``bleft`` if smaller.
            bbottom: Bottom edge; clamped up to ``btop`` if smaller.
            parent: Node that created this one, or ``None`` for the root.
        """
        # All four children are created together, so a node is a leaf
        # exactly when ``left`` is still None (see ``is_leaf``).
        self.left: PartitionTree = None
        self.right: PartitionTree = None
        self.top: PartitionTree = None
        self.bottom: PartitionTree = None

        # Degenerate boxes collapse to zero width/height instead of going negative.
        if bright < bleft:
            bright = bleft
        if bbottom < btop:
            bbottom = btop

        self.bleft = bleft
        self.bright = bright
        self.btop = btop
        self.bbottom = bbottom
        self.parent: PartitionTree = parent

    def is_leaf(self):
        """Return True while this node has not been split."""
        return self.left is None

    def new_partition(self, new_rect: List):
        """Split this node around ``new_rect`` given as ``[x1, y1, x2, y2]``.

        The four children are the strips of this node's box lying left of,
        above, right of and below ``new_rect``. For a non-root node, every
        leaf of the whole tree is afterwards shrunk away from ``new_rect``.
        """
        self.left = PartitionTree(self.bleft, self.btop, new_rect[0], self.bbottom, self)
        self.top = PartitionTree(self.bleft, self.btop, self.bright, new_rect[1], self)
        self.right = PartitionTree(new_rect[2], self.btop, self.bright, self.bbottom, self)
        self.bottom = PartitionTree(self.bleft, new_rect[3], self.bright, self.bbottom, self)
        if self.parent is not None:
            self.root_update_rect(new_rect)

    def root_update_rect(self, rect):
        """Propagate ``rect`` to every leaf of the tree, starting at the root."""
        self.get_root().update_child_rect(rect)

    def update_child_rect(self, rect: List):
        """Apply :meth:`update_from_rect` to every leaf below (or at) this node."""
        if self.is_leaf():
            self.update_from_rect(rect)
            return
        for child in (self.left, self.right, self.top, self.bottom):
            child.update_child_rect(rect)

    def get_root(self):
        """Return the top-most ancestor of this node (itself if it has no parent)."""
        if self.parent is not None:
            return self.parent.get_root()
        return self

    def update_from_rect(self, rect: List):
        """Shrink this leaf so that it avoids ``rect`` (``[x1, y1, x2, y2]``).

        Nothing happens for inner nodes or when the leaf and ``rect`` do not
        overlap with positive area. Otherwise two candidate boxes are built,
        one cut along x and one cut along y, and the candidate with the
        larger ``width * height`` product replaces the current box.
        """
        if not self.is_leaf():
            return

        overlap_w = min(self.bright, rect[2]) - max(self.bleft, rect[0])
        overlap_h = min(self.bbottom, rect[3]) - max(self.btop, rect[1])
        if not (overlap_w > 0 and overlap_h > 0):
            return

        cut_x = np.array([self.bleft, self.btop, self.bright, self.bbottom])
        cut_y = cut_x.copy()

        # Horizontal cut: keep the part left of rect when its left edge lies
        # strictly inside this box, otherwise keep the part right of rect.
        if rect[0] > self.bleft and rect[0] < self.bright:
            cut_x[2] = rect[0]
        else:
            cut_x[0] = rect[2]

        # Vertical cut: same idea along y.
        if rect[1] > self.btop and rect[1] < self.bbottom:
            cut_y[3] = rect[1]
        else:
            cut_y[1] = rect[3]

        if (cut_x[2:] - cut_x[:2]).prod() > (cut_y[2:] - cut_y[:2]).prod():
            self.bleft, self.btop, self.bright, self.bbottom = cut_x
        else:
            self.bleft, self.btop, self.bright, self.bbottom = cut_y

    def width(self) -> int:
        """Return the horizontal extent ``bright - bleft``."""
        return self.bright - self.bleft

    def height(self) -> int:
        """Return the vertical extent ``bbottom - btop``."""
        return self.bbottom - self.btop

    def prefer_partition(self, tgt_h: int, tgt_w: int):
        """Find the leaf that best fits an object of size ``tgt_h`` x ``tgt_w``.

        A leaf scores ``min(width / tgt_w, 1.2) * min(height / tgt_h, 1.2)``,
        so room beyond 1.2x the target in either direction earns nothing extra.

        Returns:
            ``(leaf, score)`` for the best-scoring leaf under this node. On
            ties the earlier branch wins, in left, right, top, bottom order.
        """
        if self.is_leaf():
            # width/height are methods and must be called; using the bound
            # method objects in arithmetic raises TypeError.
            return self, min(self.width() / tgt_w, 1.2) * min(self.height() / tgt_h, 1.2)

        candidates = [
            child.prefer_partition(tgt_h, tgt_w)
            for child in (self.left, self.right, self.top, self.bottom)
        ]
        # max() returns the first maximal element, which matches a stable
        # descending sort followed by taking element 0.
        return max(candidates, key=lambda pair: pair[1])

    def new_random_pos(self, fg_h: int, fg_w: int, im_h: int, im_w: int, random_sample: bool = False):
        """Choose a paste position for a ``fg_h`` x ``fg_w`` object in this region.

        The object may stick out of the region by up to a third of its size
        and out of the ``im_h`` x ``im_w`` image by up to a tenth of its size.

        Args:
            fg_h: Foreground height.
            fg_w: Foreground width.
            im_h: Image height.
            im_w: Image width.
            random_sample: Draw the position uniformly from the valid range
                instead of using its midpoint.

        Returns:
            ``(px, py, downscale_ratio)``. The ratio is never below 0.8 and
            may exceed 1; the position is computed for the shrunken object
            only when the ratio is below 1.
        """
        extx, exty = int(fg_w / 3), int(fg_h / 3)
        extxb, extyb = int(fg_w / 10), int(fg_h / 10)

        # width/height are methods and must be called here as well.
        region_w, region_h = self.width() + extx, self.height() + exty
        downscale_ratio = max(min(region_w / fg_w, region_h / fg_h), 0.8)
        if downscale_ratio < 1:
            fg_h = int(downscale_ratio * fg_h)
            fg_w = int(downscale_ratio * fg_w)

        max_x, max_y = self.bright + extx - fg_w, self.bbottom + exty - fg_h
        max_x = min(im_w + extxb - fg_w, max_x)
        max_y = min(im_h + extyb - fg_h, max_y)

        min_x = max(min(self.bright + extx - fg_w, self.bleft - extx), -extx)
        min_x = max(-extxb, min_x)
        min_y = max(min(self.bbottom + exty - fg_h, self.btop - exty), -exty)
        min_y = max(-extyb, min_y)

        # Fall back to the lower bound when the valid range is empty.
        px, py = min_x, min_y
        if min_x < max_x:
            if random_sample:
                px = random.randint(min_x, max_x)
            else:
                px = int((min_x + max_x) / 2)
        if min_y < max_y:
            if random_sample:
                py = random.randint(min_y, max_y)
            else:
                py = int((min_y + max_y) / 2)
        return px, py, downscale_ratio

    def drawpartition(self, image: np.ndarray, color=None):
        """Draw the boxes of inner nodes onto ``image`` (debug visualisation).

        Args:
            image: Array drawn on in place via ``cv2.rectangle``.
            color: Colour for this node's box; random when ``None``. The
                four children of a node share one freshly drawn random colour.
        """
        if color is None:
            color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        # NOTE(review): only non-leaf nodes are outlined, leaves draw nothing
        # -- confirm that this is the intended visualisation.
        if self.is_leaf():
            return
        cv2.rectangle(image, (self.bleft, self.btop), (self.bright, self.bbottom), color, 2)
        child_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        for child in (self.left, self.right, self.top, self.bottom):
            child.drawpartition(image, child_color)
def paste_one_fg(fg_pil: Image, bg: Image, segments: np.ndarray, px: int, py: int, seg_color: Tuple, cal_area=True):
    """Paste one foreground onto ``bg`` and stamp its mask into ``segments``.

    Args:
        fg_pil: Foreground image; it is used as its own paste mask and its
            fourth channel is read as alpha (presumably RGBA -- confirm).
        bg: Background image, modified in place.
        segments: Colour-coded segmentation map, modified in place.
        px: X coordinate of the foreground's top-left corner; may be negative.
        py: Y coordinate of the foreground's top-left corner; may be negative.
        seg_color: Colour written to ``segments`` where alpha exceeds 30.
        cal_area: Count the mask pixels when true; otherwise report area 1.

    Returns:
        ``(area, bbox, xyxy)`` where ``bbox`` is ``[x, y, w, h]`` and
        ``xyxy`` is ``[x1, y1, x2, y2]``, both clipped to the background.
    """
    fg_w, fg_h = fg_pil.width, fg_pil.height
    im_w, im_h = bg.width, bg.height

    bg.paste(fg_pil, (px, py), mask=fg_pil)

    # Destination window on the background, and the matching source window
    # on the foreground, before clipping.
    bgx1, bgy1 = px, py
    bgx2, bgy2 = px + fg_w, py + fg_h
    fgx1, fgy1 = 0, 0
    fgx2, fgy2 = fg_w, fg_h

    # Clip against the top-left edges: skip the part of the source that
    # falls outside.
    if bgx1 < 0:
        fgx1, bgx1 = -bgx1, 0
    if bgy1 < 0:
        fgy1, bgy1 = -bgy1, 0

    # Clip against the bottom-right edges: the source end becomes a negative
    # index, i.e. "drop this many trailing columns/rows".
    if bgx2 > im_w:
        fgx2, bgx2 = im_w - bgx2, im_w
    if bgy2 > im_h:
        fgy2, bgy2 = im_h - bgy2, im_h

    alpha = np.array(fg_pil)[fgy1: fgy2, fgx1: fgx2, 3]
    fg_mask = alpha > 30

    # Basic slicing yields a view, so this writes straight into ``segments``.
    target = segments[bgy1: bgy2, bgx1: bgx2]
    target[np.where(fg_mask)] = seg_color

    area = fg_mask.sum() if cal_area else 1
    xyxy = [bgx1, bgy1, bgx2, bgy2]
    bbox = [bgx1, bgy1, bgx2 - bgx1, bgy2 - bgy1]
    return area, bbox, xyxy
def partition_paste(fg_list, bg: Image):
    """Paste all foregrounds onto ``bg`` at positions chosen by a PartitionTree.

    ``fg_list`` is sorted in place by image area, largest first. Each entry
    is placed in the best-scoring free region, shrunk if that region asks
    for a ratio below 1, and the region is then split around the pasted box.

    Args:
        fg_list: Dicts whose ``'image'`` value is an array accepted by
            ``Image.fromarray``.
        bg: Background image, modified in place.

    Returns:
        ``(segments_info, segments)``: one ``{'id', 'bbox', 'area'}`` dict
        per foreground (area is always 1 here because it is not computed),
        and the colour-coded segmentation map.
    """
    canvas_h, canvas_w = bg.height, bg.width

    # Largest objects are placed first.
    fg_list.sort(key=lambda item: item['image'].shape[0] * item['image'].shape[1], reverse=True)

    tree = PartitionTree(0, 0, bg.width, bg.height)
    segments = np.zeros((canvas_h, canvas_w, 3), np.uint8)
    segments_info = []

    for idx, entry in enumerate(fg_list):
        fg_arr = entry['image']
        fg_h, fg_w = fg_arr.shape[:2]

        region, _ = tree.prefer_partition(fg_h, fg_w)
        px, py, ratio = region.new_random_pos(fg_h, fg_w, canvas_h, canvas_w, True)

        fg_pil = Image.fromarray(fg_arr)
        if ratio < 1:
            shrunk_size = (int(fg_w * ratio), int(fg_h * ratio))
            fg_pil = fg_pil.resize(shrunk_size, resample=Image.Resampling.LANCZOS)

        seg_color = COLOR_PALETTE[idx]
        area, bbox, xyxy = paste_one_fg(fg_pil, bg, segments, px, py, seg_color, cal_area=False)
        region.new_partition(xyxy)

        segments_info.append({
            'id': rgb2id(seg_color),
            'bbox': bbox,
            'area': area
        })

    return segments_info, segments
| # if downscale_ratio < 1: | |
| # fg_pil = fg_pil.resize((int(fg_w * downscale_ratio), int(fg_h * downscale_ratio)), resample=Image.Resampling.LANCZOS) | |
| # fg_h, fg_w = fg_pil.height, fg_pil.width | |
def gen_fg_regbboxes(fg_list: List[Dict], tgt_size: int, min_overlap=0.15, max_overlap=0.8):
    """Lay the foregrounds out left to right and give each a random depth.

    Each entry of ``fg_list`` receives a ``'pos'`` (``[x, y]``) and a
    ``'depth'`` key, and the list is then sorted in place by depth, largest
    first. Entries wider than ``0.6 * tgt_size`` get 1 added to their depth.

    Args:
        fg_list: Dicts whose ``'image'`` value exposes ``.shape``.
        tgt_size: Canvas size, used for both vertical centring and the
            horizontal extent.
        min_overlap: Lower bound of the horizontal overlap fraction.
        max_overlap: Upper bound of the horizontal overlap fraction.
    """

    def _sample_y(h):
        # Vertically centre the object, jittered by up to a quarter of its
        # height; anything at least as tall as the canvas sits at y = 0.
        centered = (tgt_size - h) // 2
        if centered <= 0:
            return 0
        jitter = min(centered, h // 4)
        return centered + random.randint(-jitter, jitter)

    shape_list = np.array([entry['image'].shape[:2] for entry in fg_list])
    depth_list = np.random.random(len(fg_list))
    is_wide = shape_list[..., 1] > 0.6 * tgt_size
    depth_list[is_wide] += 1

    # num_fg = len(fg_list)
    # grid_sample = random.random() < 0.4 or num_fg > 6
    # grid_sample = grid_sample and num_fg < 9 and num_fg > 3
    # grid_sample = False
    # if grid_sample:
    #     grid_pos = np.arange(9)
    #     np.random.shuffle(grid_pos)
    #     grid_pos = grid_pos[: num_fg]
    #     grid_x = grid_pos % 3
    #     grid_y = grid_pos // 3
    # else:

    positions = [[0, _sample_y(shape_list[0][0])]]
    pre_overlap = 0
    for idx in range(1, len(shape_list)):
        h, w = shape_list[idx]
        depth = depth_list[idx]
        prev_h, prev_w = shape_list[idx - 1]
        prev_depth = depth_list[idx - 1]
        prev_x, prev_y = positions[idx - 1]

        y = _sample_y(h)
        x = prev_x + prev_w  # start flush against the previous object's right edge

        if depth < prev_depth:
            # Smaller depth than the previous object: overlap is measured
            # against the previous object's width, reduced by the overlap
            # the previous object already has.
            lo = hi = x
            if pre_overlap < max_overlap:
                lo = int(x - (max_overlap - pre_overlap) * prev_w)
            if pre_overlap < min_overlap:
                hi = int(x - (min_overlap - pre_overlap) * prev_w)
            x = random.randint(lo, hi)
            pre_overlap = 0
        else:
            # Otherwise slide left by a fraction of this object's own width
            # and remember how much of it the previous object overlaps.
            x -= int(random.uniform(min_overlap, max_overlap) * w)
            overlap_area = bbox_overlap_area([x, y, w, h], [prev_x, prev_y, prev_w, prev_h])
            pre_overlap = overlap_area / (h * w)

        positions.append([x, y])

    positions = np.array(positions)

    # Shift the whole row: randomly within the free space on the right, or
    # by half the overshoot when the row is wider than the canvas.
    slack = tgt_size - (positions[-1][0] + shape_list[-1][1])
    shift = random.randint(0, slack) if slack > 0 else slack // 2
    positions[:, 0] += shift

    for pos, entry, depth in zip(positions, fg_list, depth_list):
        entry['pos'] = pos
        entry['depth'] = depth

    fg_list.sort(key=lambda entry: entry['depth'], reverse=True)
def regular_paste(fg_list, bg: Image, regen_bboxes=False):
    """Paste foregrounds onto ``bg`` at the positions stored in each entry.

    Args:
        fg_list: Dicts with an ``'image'`` array and a ``'pos'`` pair. The
            ``'pos'`` key is removed from every entry as it is consumed.
        bg: Background image, modified in place.
        regen_bboxes: When true, shuffle ``fg_list`` and regenerate every
            position with ``gen_fg_regbboxes`` (using the background height
            as the target size) before pasting.

    Returns:
        ``(segments_info, segments)``: one ``{'id', 'bbox', 'area'}`` dict
        per foreground and the colour-coded segmentation map.
    """
    canvas_h, canvas_w = bg.height, bg.width

    if regen_bboxes:
        random.shuffle(fg_list)
        gen_fg_regbboxes(fg_list, canvas_h)

    segments = np.zeros((canvas_h, canvas_w, 3), np.uint8)
    segments_info = []

    for idx, entry in enumerate(fg_list):
        fg_arr = entry['image']
        px, py = entry.pop('pos')
        seg_color = COLOR_PALETTE[idx]
        area, bbox, _ = paste_one_fg(
            Image.fromarray(fg_arr), bg, segments, px, py, seg_color, cal_area=True
        )
        segments_info.append({
            'id': rgb2id(seg_color),
            'bbox': bbox,
            'area': area
        })

    return segments_info, segments