|
import os |
|
from pycocotools import mask as mask_util |
|
import json |
|
import numpy as np |
|
import cv2 |
|
from distinctipy import distinctipy |
|
import matplotlib.pyplot as plt |
|
from PIL import Image |
|
from types import MethodType |
|
import json |
|
import random |
|
|
|
import torch |
|
import torchvision |
|
from detectron2.data import MetadataCatalog |
|
from detectron2.structures import BitMasks, PolygonMasks |
|
from detectron2.utils.visualizer import ColorMode, Visualizer |
|
from detectron2.data.detection_utils import read_image |
|
|
|
from third_parts.APE.build_ape import build_ape_predictor |
|
|
|
|
|
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Return the candidate grid whose aspect ratio is nearest to ``aspect_ratio``.

    Args:
        aspect_ratio: Reference width / height ratio to approximate.
        target_ratios: Iterable of 2-item sequences ``(a, b)``; each candidate's
            ratio is ``a / b``.
        width: Width used, together with ``height``, for tie-breaking by area.
        height: Height used, together with ``width``, for tie-breaking by area.
        image_size: Side length of one tile, used only in the tie-break
            threshold.

    Returns:
        The winning element of ``target_ratios``, or ``(1, 1)`` when no
        candidate is supplied.
    """
    pixel_area = width * height
    half_tile_area = 0.5 * image_size * image_size

    smallest_gap = float('inf')
    chosen = (1, 1)
    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
            continue
        # On an exact tie, a later candidate only wins when the region is
        # larger than half of the area that candidate's tiles would cover.
        if gap == smallest_gap and pixel_area > half_tile_area * cols * rows:
            chosen = candidate
    return chosen
|
|
|
def sample_points(box, mask, min_points=3, max_points=16, dense_max_points=32):
    """Sample foreground prompt points from ``mask`` inside ``box``.

    Two sets of points are concatenated:

    1. Centres of a regular grid laid over the box. The grid shape is the
       ``(columns, rows)`` pair, with ``min_points <= columns * rows <=
       max_points``, picked by ``find_closest_aspect_ratio``. Only centres
       that fall on a positive mask value are kept.
    2. A strided subset of every positive pixel in the box, scanned column by
       column, with the first and last selected entries dropped.

    Args:
        box: ``(x0, y0, w, h)`` of the region to sample.
        mask: 2-D array-like indexed as ``mask[y, x]``.
        min_points: Lower bound on the number of grid cells.
        max_points: Upper bound on the number of grid cells.
        dense_max_points: Divisor that sets the stride of the dense subset.

    Returns:
        List of ``(x, y)`` tuples; grid points are floats, dense points ints.

    Raises:
        ZeroDivisionError: If ``h`` or ``dense_max_points`` is zero.
    """
    left, top, box_w, box_h = box
    box_aspect = box_w / box_h

    # Candidate (columns, rows) grid shapes, ordered by total cell count.
    grid_shapes = set(
        (i, j)
        for n in range(min_points, max_points + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if i * j <= max_points and i * j >= min_points
    )
    grid_shapes = sorted(grid_shapes, key=lambda shape: shape[0] * shape[1])

    best_shape = find_closest_aspect_ratio(box_aspect, grid_shapes, box_w, box_h, 50)
    num_cols, num_rows = best_shape[0], best_shape[1]
    cell_w = box_w / num_cols
    cell_h = box_h / num_rows

    # Grid-cell centres that land on the mask.
    grid_points = []
    for col in range(num_cols):
        px = left + (col + 0.5) * cell_w
        for row in range(num_rows):
            py = top + (row + 0.5) * cell_h
            if mask[int(py), int(px)] > 0:
                grid_points.append((px, py))

    # Every positive pixel inside the box, column-major.
    foreground = [
        (px, py)
        for px in range(int(left), int(left + box_w))
        for py in range(int(top), int(top + box_h))
        if mask[int(py), int(px)] > 0
    ]
    stride = len(foreground) // dense_max_points
    if stride < 1:
        picks = list(range(0, len(foreground)))
    else:
        picks = list(range(0, len(foreground), stride))
    dense_points = [foreground[pick] for pick in picks[1:-1]]

    return grid_points + dense_points
|
|
|
|
|
def mask_iou(masks, chunk_size=50, chunk_mode=False):
    """Compute pairwise intersection and union pixel counts between masks.

    Every mask is compared with every other mask (and itself). The counts are
    obtained by summing over the last two dimensions.

    Args:
        masks: Tensor whose first dimension indexes the masks (presumably
            binary masks shaped ``(N, H, W)``; confirm against callers).
        chunk_size: Number of masks per block in chunk mode. Values below 1
            are treated as 1.
        chunk_mode: When truthy, the pairwise products are evaluated block by
            block to bound peak memory. If a block runs out of CUDA memory the
            whole computation is retried with half, then a quarter, of
            ``chunk_size``.

    Returns:
        Tuple ``(intersection, union)`` of ``(N, N)`` tensors.

    Raises:
        torch.cuda.OutOfMemoryError: In chunk mode, when every candidate chunk
            size runs out of memory. Previously the function silently returned
            ``None`` in that case.
    """
    masks1 = masks.unsqueeze(1).char()
    masks2 = masks.unsqueeze(0).char()
    num_masks = masks1.shape[0]

    def pairwise_counts(rows, cols):
        intersection = rows * cols
        union = (rows + cols - intersection).sum(-1).sum(-1)
        return intersection.sum(-1).sum(-1), union

    # With no masks there is nothing to chunk; torch.cat would reject the
    # empty block list, so use the dense path, which yields (0, 0) tensors.
    if not chunk_mode or num_masks == 0:
        return pairwise_counts(masks1, masks2)

    def chunked_counts(size):
        """Return (ok, intersection, union); ok is False on CUDA OOM."""
        row_intersections, row_unions = [], []
        for row_start in range(0, num_masks, size):
            rows = masks1[row_start:row_start + size]
            col_intersections, col_unions = [], []
            for col_start in range(0, num_masks, size):
                cols = masks2[:, col_start:col_start + size]
                try:
                    intersection, union = pairwise_counts(rows, cols)
                except torch.cuda.OutOfMemoryError:
                    # Return instead of re-raising so this frame, and the
                    # tensors it holds, is released before the retry.
                    return False, None, None
                col_intersections.append(intersection)
                col_unions.append(union)
            row_intersections.append(torch.cat(col_intersections, dim=1))
            row_unions.append(torch.cat(col_unions, dim=1))
        return True, torch.cat(row_intersections, dim=0), torch.cat(row_unions, dim=0)

    # Progressively smaller blocks; clamp to 1 so a small chunk_size can never
    # produce a zero-sized step, and skip sizes already tried.
    candidate_sizes = []
    for size in (chunk_size, chunk_size // 2, chunk_size // 4):
        size = max(1, size)
        if size not in candidate_sizes:
            candidate_sizes.append(size)

    for size in candidate_sizes:
        is_ok, intersection, union = chunked_counts(size)
        if is_ok:
            return intersection, union

    raise torch.cuda.OutOfMemoryError(
        "mask_iou ran out of memory for every chunk size in {}".format(candidate_sizes)
    )
|
|
|
def mask_iou_v2(masks1, masks2, chunk_size=50, chunk_mode=False):
    """Compute intersection and union pixel counts between two mask sets.

    Every mask of ``masks1`` is compared with every mask of ``masks2``. The
    counts are obtained by summing over the last two dimensions.

    Args:
        masks1: Tensor whose first dimension indexes the first set of masks
            (presumably binary, shaped ``(N1, H, W)``; confirm against
            callers).
        masks2: Tensor whose first dimension indexes the second set of masks
            (presumably ``(N2, H, W)``).
        chunk_size: Number of masks per block in chunk mode. Values below 1
            are treated as 1.
        chunk_mode: When truthy, the pairwise products are evaluated block by
            block to bound peak memory. If a block runs out of CUDA memory the
            whole computation is retried with half, then a quarter, of
            ``chunk_size``.

    Returns:
        Tuple ``(intersection, union)`` of ``(N1, N2)`` tensors.

    Raises:
        torch.cuda.OutOfMemoryError: In chunk mode, when every candidate chunk
            size runs out of memory. Previously ``(None, None)`` was returned.
    """
    masks1 = masks1.unsqueeze(1).char()
    masks2 = masks2.unsqueeze(0).char()
    num_rows = masks1.shape[0]
    # After unsqueeze(0) the second set lives on dimension 1, so its size must
    # be read from shape[1]; shape[0] is always 1.
    num_cols = masks2.shape[1]

    def pairwise_counts(rows, cols):
        intersection = rows * cols
        union = (rows + cols - intersection).sum(-1).sum(-1)
        return intersection.sum(-1).sum(-1), union

    # An empty side leaves nothing to chunk; torch.cat would reject the empty
    # block list, so use the dense path, which yields correctly shaped empty
    # tensors.
    if not chunk_mode or num_rows == 0 or num_cols == 0:
        return pairwise_counts(masks1, masks2)

    def chunked_counts(size):
        """Return (ok, intersection, union); ok is False on CUDA OOM."""
        row_intersections, row_unions = [], []
        for row_start in range(0, num_rows, size):
            rows = masks1[row_start:row_start + size]
            col_intersections, col_unions = [], []
            for col_start in range(0, num_cols, size):
                cols = masks2[:, col_start:col_start + size]
                try:
                    intersection, union = pairwise_counts(rows, cols)
                except torch.cuda.OutOfMemoryError:
                    # Return instead of re-raising so this frame, and the
                    # tensors it holds, is released before the retry.
                    return False, None, None
                col_intersections.append(intersection)
                col_unions.append(union)
            row_intersections.append(torch.cat(col_intersections, dim=1))
            row_unions.append(torch.cat(col_unions, dim=1))
        return True, torch.cat(row_intersections, dim=0), torch.cat(row_unions, dim=0)

    # Progressively smaller blocks; clamp to 1 so a small chunk_size can never
    # produce a zero-sized step, and skip sizes already tried.
    candidate_sizes = []
    for size in (chunk_size, chunk_size // 2, chunk_size // 4):
        size = max(1, size)
        if size not in candidate_sizes:
            candidate_sizes.append(size)

    for size in candidate_sizes:
        is_ok, intersection, union = chunked_counts(size)
        if is_ok:
            return intersection, union

    raise torch.cuda.OutOfMemoryError(
        "mask_iou_v2 ran out of memory for every chunk size in {}".format(candidate_sizes)
    )
|
|
|
|
|
def mask_area(masks, chunk_size=50, chunk_mode=False):
    """Return the per-mask sum over the last two dimensions.

    Args:
        masks: Tensor whose first dimension indexes the masks (presumably
            binary masks shaped ``(N, H, W)``; confirm against callers).
        chunk_size: Number of masks reduced at a time in chunk mode.
        chunk_mode: When truthy, reduce ``chunk_size`` masks at a time and
            concatenate the partial results.

    Returns:
        Tensor with one area value per mask along the first dimension.

    Raises:
        ValueError: If ``chunk_mode`` is set and ``chunk_size`` is below 1.
    """
    num_masks = masks.shape[0]
    # With no masks there are no chunks, and torch.cat rejects an empty list,
    # so the direct reduction is used to return an empty result.
    if not chunk_mode or num_masks == 0:
        return masks.sum(-1).sum(-1)

    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer, got {}".format(chunk_size))

    areas = [
        masks[start:start + chunk_size].sum(-1).sum(-1)
        for start in range(0, num_masks, chunk_size)
    ]
    return torch.cat(areas, dim=0)
|
|
|
|
|
def run_on_image(image_file, anno_file, save_path, sam_predictor, sam_auto_mask_generator):
    """Propose part masks for every annotated instance of one image and save them.

    For each instance in ``anno_file`` the function:

    1. decodes the instance mask and crops a loosened patch around it (grown
       to at least 256 px per side and towards a square, where the image
       allows), then resizes the patch so its longer side is 1024 px;
    2. prompts ``sam_predictor`` with points sampled inside the instance and
       runs ``sam_auto_mask_generator`` on the same patch;
    3. keeps masks scored above 0.9, drops lower-scored masks whose IoU with a
       kept mask exceeds 0.8, and retains masks that lie more than 95% inside
       the instance while having IoU below 0.8 with it;
    4. pastes the retained masks back to full image resolution and RLE-encodes
       them.

    The results of all instances are written to
    ``<save_path>/<image name up to its first dot>.json`` as a list of
    ``{"root_id", "part_id", "segmentation"}`` records. No file is written
    when no part is found.

    Args:
        image_file: Path of the image; a missing or unreadable file is skipped.
        anno_file: JSON file holding a list of instances, each with ``ins_id``
            and an RLE ``segmentation``.
        save_path: Output directory (created if missing).
        sam_predictor: Object exposing ``device``, ``set_image`` and
            ``predict_torch`` (presumably a segment-anything ``SamPredictor``;
            confirm against the caller).
        sam_auto_mask_generator: Object whose ``generate(image)`` returns
            dicts with an RLE ``segmentation`` and a ``predicted_iou``
            (presumably a ``SamAutomaticMaskGenerator`` in an RLE output mode;
            confirm against the caller).

    Returns:
        None.
    """
    if not os.path.exists(image_file):
        return None
    file_name = os.path.basename(image_file).split('.')[0]
    with open(anno_file, 'r') as f:
        json_results = json.load(f)

    sam_image = cv2.imread(image_file)
    if sam_image is None:
        # cv2.imread reports an unreadable file by returning None.
        return None
    ori_height, ori_width = sam_image.shape[:2]
    sam_image = cv2.cvtColor(sam_image, cv2.COLOR_BGR2RGB)

    device = sam_predictor.device

    # Collected across all instances so one instance does not overwrite the
    # output of another.
    save_json_results = []
    for ins_anno in json_results:
        root_ins_id = ins_anno['ins_id']

        root_mask = torch.from_numpy(_decode_rle_mask(ins_anno['segmentation'])).unsqueeze(0)
        if not torch.any(root_mask):
            # masks_to_boxes cannot handle an all-zero mask.
            continue
        root_bbox = torchvision.ops.masks_to_boxes(root_mask)[0].numpy().tolist()

        # Loosen the box by a quarter of its size on every side, clipped to
        # the image.
        box_w = root_bbox[2] - root_bbox[0]
        box_h = root_bbox[3] - root_bbox[1]
        loose_box_x0 = max(int(root_bbox[0] - box_w // 4), 0)
        loose_box_y0 = max(int(root_bbox[1] - box_h // 4), 0)
        loose_box_x1 = min(int(root_bbox[2] + box_w // 4), ori_width)
        loose_box_y1 = min(int(root_bbox[3] + box_h // 4), ori_height)

        loose_box_w = loose_box_x1 - loose_box_x0
        loose_box_h = loose_box_y1 - loose_box_y0
        if not (loose_box_w >= box_w and loose_box_h >= box_h):
            continue

        # Grow small boxes to at least 256 px per side.
        if loose_box_w < 256:
            loose_box_x0, loose_box_x1 = _expand_span(loose_box_x0, loose_box_x1, 256, ori_width)
        if loose_box_h < 256:
            loose_box_y0, loose_box_y1 = _expand_span(loose_box_y0, loose_box_y1, 256, ori_height)

        # Grow the shorter side towards a square crop.
        loose_box_w = loose_box_x1 - loose_box_x0
        loose_box_h = loose_box_y1 - loose_box_y0
        if loose_box_w > loose_box_h:
            loose_box_y0, loose_box_y1 = _expand_span(loose_box_y0, loose_box_y1, loose_box_w, ori_height)
        elif loose_box_h > loose_box_w:
            loose_box_x0, loose_box_x1 = _expand_span(loose_box_x0, loose_box_x1, loose_box_h, ori_width)

        loose_box_w = loose_box_x1 - loose_box_x0
        loose_box_h = loose_box_y1 - loose_box_y0
        # Crop from the RGB array: a PIL image cannot be sliced, and the
        # clipping bounds above come from this array's shape.
        image_patch = sam_image[loose_box_y0:loose_box_y1, loose_box_x0:loose_box_x1, :]
        ori_image_patch_h, ori_image_patch_w = image_patch.shape[:2]
        root_mask_patch = root_mask[:, loose_box_y0:loose_box_y1, loose_box_x0:loose_box_x1]

        # Resize so the longer side is 1024 px.
        if loose_box_w > loose_box_h:
            target_w = 1024
            target_h = int(loose_box_h / loose_box_w * target_w)
        else:
            target_h = 1024
            target_w = int(loose_box_w / loose_box_h * target_h)
        image_patch = cv2.resize(image_patch, dsize=(target_w, target_h), interpolation=cv2.INTER_LINEAR)
        root_mask_patch = torch.nn.functional.interpolate(
            root_mask_patch[None].to(torch.float32), size=(target_h, target_w), mode="bilinear")
        root_mask_patch = (root_mask_patch[0] > 0.5).to(torch.int8)
        if not torch.any(root_mask_patch):
            continue

        sam_predictor.set_image(image_patch)

        # Point prompts sampled inside the instance.
        root_bbox_patch = torchvision.ops.masks_to_boxes(root_mask_patch)
        x0, y0, x1, y1 = root_bbox_patch[0].numpy().tolist()
        if y1 <= y0:
            # sample_points divides by the box height.
            continue
        ret_points = sample_points(
            [x0, y0, x1 - x0, y1 - y0], root_mask_patch[0],
            min_points=3, max_points=16, dense_max_points=32)
        if not ret_points:
            continue
        ret_points_list = [list(point) for point in ret_points]
        point_coords = torch.tensor(ret_points_list, dtype=torch.float32, device=device).unsqueeze(1)
        point_labels = torch.ones(size=(point_coords.shape[0], 1), dtype=torch.int, device=device)

        # Automatic proposals on the same patch.
        generated_annos = sam_auto_mask_generator.generate(image_patch)
        auto_sam_masks, auto_iou_scores = [], []
        for object_anno in generated_annos:
            auto_sam_masks.append(torch.from_numpy(_decode_rle_mask(object_anno["segmentation"])))
            auto_iou_scores.append(object_anno['predicted_iou'])
        if auto_sam_masks:
            auto_sam_masks = torch.stack(auto_sam_masks)
        else:
            # torch.stack rejects an empty list.
            auto_sam_masks = torch.zeros((0, target_h, target_w), dtype=torch.uint8)
        auto_iou_scores = torch.as_tensor(auto_iou_scores, dtype=torch.float32)

        part_masks, part_masks_score, _ = sam_predictor.predict_torch(
            point_coords=point_coords,
            point_labels=point_labels,
            boxes=None,
            multimask_output=True,
        )
        batch_size, num_masks_per_input = part_masks.shape[:2]

        print(part_masks.device)

        # For every prompt keep the smallest of its candidate masks.
        # torch.gather needs an index with the same rank as its input, so
        # plain advanced indexing is used instead.
        part_masks_area = mask_area(part_masks.flatten(0, 1), chunk_size=50, chunk_mode=True)
        part_masks_area = part_masks_area.reshape(batch_size, num_masks_per_input)
        part_masks_idx = torch.argmin(part_masks_area, dim=1)
        batch_indices = torch.arange(batch_size, device=part_masks.device)
        part_masks = part_masks[batch_indices, part_masks_idx]

        print(part_masks.shape)

        part_masks_score = part_masks_score[batch_indices, part_masks_idx]
        confident = part_masks_score > 0.9
        part_masks = part_masks[confident]

        print(part_masks.shape)

        # Merge both proposal sources on the predictor's device and dtype.
        auto_confident = auto_iou_scores > 0.9
        auto_sam_masks = auto_sam_masks[auto_confident].to(device=part_masks.device, dtype=part_masks.dtype)
        part_masks = torch.cat([part_masks, auto_sam_masks], dim=0)
        part_masks_score = torch.cat(
            [part_masks_score[confident], auto_iou_scores[auto_confident].to(part_masks_score)], dim=0)
        if part_masks.shape[0] == 0:
            continue

        # Highest score first, so suppression keeps the better mask.
        scores = part_masks_score.tolist()
        sorted_indices = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
        sorted_part_masks = part_masks[sorted_indices]

        # IoU is computed at quarter resolution. Bilinear interpolation needs
        # a floating-point input.
        low_res_size = (target_h // 4, target_w // 4)
        downsampled_part_masks = torch.nn.functional.interpolate(
            sorted_part_masks[None].to(torch.float32), size=low_res_size, mode="bilinear")
        downsampled_part_masks = (downsampled_part_masks[0] > 0.5).to(sorted_part_masks.dtype)

        intersection, union = mask_iou(downsampled_part_masks, chunk_size=50, chunk_mode=True)
        is_duplicate = ((intersection / union) > 0.8).tolist()

        num_instances = len(is_duplicate)
        keep = [True] * num_instances
        for ins_i in range(num_instances):
            if not keep[ins_i]:
                continue
            for ins_j in range(ins_i + 1, num_instances):
                if is_duplicate[ins_i][ins_j]:
                    keep[ins_j] = False

        downsampled_root_mask_patch = torch.nn.functional.interpolate(
            root_mask_patch[None].to(torch.float32), size=low_res_size, mode="bilinear")
        downsampled_root_mask_patch = (downsampled_root_mask_patch[0] > 0.5).to(
            root_mask_patch.dtype).to(downsampled_part_masks.device)

        intersection, union = mask_iou_v2(
            downsampled_root_mask_patch, downsampled_part_masks, chunk_size=50, chunk_mode=True)
        downsampled_part_masks_area = mask_area(downsampled_part_masks, chunk_mode=True, chunk_size=50)
        # Not named ``mask_iou``: that would shadow the module-level function
        # for the whole body and break the call above.
        root_iou = intersection[0] / union[0]
        mask_roc = intersection[0] / downsampled_part_masks_area

        maybe_is_part = (root_iou < 0.8) & (mask_roc > 0.95) & torch.as_tensor(keep, device=root_iou.device)

        if not torch.any(maybe_is_part):
            continue

        # Paste the surviving masks back at full image resolution.
        left_part_masks = sorted_part_masks[maybe_is_part]
        left_part_masks = torch.nn.functional.interpolate(
            left_part_masks[None].to(torch.float32), size=(ori_image_patch_h, ori_image_patch_w), mode="bilinear")
        left_part_masks = (left_part_masks[0] > 0.5).to(root_mask.dtype).to(root_mask.device)
        full_size_part_masks = torch.zeros_like(root_mask).repeat(left_part_masks.shape[0], 1, 1)
        full_size_part_masks[:, loose_box_y0:loose_box_y1, loose_box_x0:loose_box_x1] = left_part_masks
        full_size_part_masks = full_size_part_masks.cpu().numpy()

        for part_idx, mask in enumerate(full_size_part_masks):
            rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
            rle["counts"] = rle["counts"].decode("utf-8")
            save_json_results.append({
                "root_id": root_ins_id,
                "part_id": part_idx+1,
                "segmentation": rle,
            })

    if save_json_results:
        os.makedirs(save_path, exist_ok=True)
        with open(os.path.join(save_path, file_name+'.json'), 'w') as f:
            json.dump(save_json_results, f)
    return None


def _decode_rle_mask(segmentation):
    """Decode a COCO RLE dict (compressed or uncompressed) to a uint8 array."""
    if isinstance(segmentation["counts"], list):
        segmentation = mask_util.frPyObjects(
            segmentation, segmentation["size"][0], segmentation["size"][1])
    return mask_util.decode(segmentation).astype(np.uint8).squeeze()


def _expand_span(start, stop, target_length, limit):
    """Grow ``[start, stop)`` to ``target_length`` inside ``[0, limit]``.

    Padding is split evenly; whatever does not fit on one side is moved to
    the other, and the result is finally clipped to the valid range.
    """
    extra = target_length - (stop - start)
    before = extra // 2
    after = extra - before
    if start - before < 0:
        after = after + before - start
        before = start
    if stop + after > limit:
        before = before + stop + after - limit
        after = limit - stop
    new_start = max(int(start - before), 0)
    new_stop = min(int(stop + after), limit)
    return new_start, new_stop
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|