|
import os |
|
from pycocotools import mask as mask_util |
|
import json |
|
import numpy as np |
|
import cv2 |
|
from distinctipy import distinctipy |
|
import matplotlib.pyplot as plt |
|
from PIL import Image |
|
from types import MethodType |
|
import random |
|
import sys |
|
|
|
import torch |
|
import torchvision |
|
from detectron2.data import MetadataCatalog |
|
from detectron2.structures import BitMasks, PolygonMasks |
|
from detectron2.utils.visualizer import ColorMode, Visualizer |
|
from detectron2.data.detection_utils import read_image |
|
|
|
from fvcore.common.timer import Timer |
|
|
|
from third_parts.APE.build_ape import build_ape_predictor |
|
from third_parts.recognize_anything.build_ram_plus import build_ram_predictor |
|
from third_parts.segment_anything import build_sam_vit_h, SamPredictor, SamAutomaticMaskGenerator |
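
# Pipeline overview (summary of the code below): fuse SAM automatic masks (or
# pre-computed SA-1B annotations) with open-vocabulary APE detections guided by
# RAM image tags, de-duplicate the merged masks, and dump object-level
# annotations as COCO-style RLE JSON.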
|
|
|
def show_mask(mask, ax, random_color=False): |
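    """Overlay a single binary mask on a matplotlib axis as a translucent color."""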
|
if random_color: |
|
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) |
|
else: |
|
color = np.array([30/255, 144/255, 255/255, 0.6]) |
|
h, w = mask.shape[-2:] |
|
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) |
|
ax.imshow(mask_image) |
|
|
|
|
|
|
|
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): |
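    """Pick the (cols, rows) grid from target_ratios whose aspect ratio is closest to
    aspect_ratio; ties are broken toward larger grids when the box area is large
    relative to image_size."""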
|
best_ratio_diff = float('inf') |
|
best_ratio = (1, 1) |
|
area = width * height |
|
for ratio in target_ratios: |
|
target_aspect_ratio = ratio[0] / ratio[1] |
|
ratio_diff = abs(aspect_ratio - target_aspect_ratio) |
|
if ratio_diff < best_ratio_diff: |
|
best_ratio_diff = ratio_diff |
|
best_ratio = ratio |
|
elif ratio_diff == best_ratio_diff: |
|
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: |
|
best_ratio = ratio |
|
return best_ratio |
|
|
|
def sample_points(box, mask, min_points=3, max_points=16): |
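    """Sample foreground points of `mask` inside `box` (x, y, w, h): first on a grid
    shaped by the box aspect ratio, then topped up with uniformly spaced foreground
    pixels."""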
|
x0, y0, w, h = box |
|
aspect_ratio = w / h |
|
|
|
|
|
target_ratios = set( |
|
(i, j) for n in range(min_points, max_points + 1) for i in range(1, n + 1) for j in range(1, n + 1) if |
|
i * j <= max_points and i * j >= min_points) |
|
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) |
|
|
|
|
|
target_aspect_ratio = find_closest_aspect_ratio( |
|
aspect_ratio, target_ratios, w, h, 50) |
|
width_bin = w / target_aspect_ratio[0] |
|
height_bin = h / target_aspect_ratio[1] |
|
|
|
ret_points = [] |
|
for wi in range(target_aspect_ratio[0]): |
|
xi = x0 + (wi+0.5) * width_bin |
|
for hi in range(target_aspect_ratio[1]): |
|
yi = y0 + (hi+0.5) * height_bin |
|
if mask[int(yi), int(xi)] > 0: |
|
ret_points.append((xi, yi)) |
|
|
|
|
|
temp_points = [] |
|
for wi in range(int(x0), int(x0+w)): |
|
for hi in range(int(y0), int(y0+h)): |
|
if mask[int(hi), int(wi)] > 0: |
|
temp_points.append((wi, hi)) |
|
if len(temp_points)//max_points < 1: |
|
uniform_indices = list(range(0, len(temp_points))) |
|
else: |
|
uniform_indices = list(range(0, len(temp_points), len(temp_points)//max_points)) |
|
additional_points = [temp_points[uniform_idx] for uniform_idx in uniform_indices[1:-1]] |
|
|
|
ret_points = ret_points + additional_points |
|
return ret_points |
|
|
|
def mask_iou(masks, chunk_size=50, chunk_mode=False): |
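    """Pairwise intersection and union among `masks` (N, H, W). With chunk_mode=True,
    compute block-wise and retry with progressively smaller chunks on CUDA OOM."""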
|
masks1 = masks.unsqueeze(1).char() |
|
masks2 = masks.unsqueeze(0).char() |
|
|
|
if not chunk_mode: |
|
intersection = (masks1 * masks2) |
|
union = (masks1 + masks2 - intersection).sum(-1).sum(-1) |
|
intersection = intersection.sum(-1).sum(-1) |
|
return intersection, union |
|
|
|
def chunk_mask_iou(_chunk_size=50): |
|
|
|
num_chunks = masks1.shape[0] // _chunk_size |
|
if masks1.shape[0] % _chunk_size > 0: |
|
num_chunks += 1 |
|
|
|
row_chunks_intersection, row_chunks_union = [], [] |
|
for row_idx in range(num_chunks): |
|
col_chunks_intersection, col_chunks_union = [], [] |
|
masks1_chunk = masks1[row_idx*_chunk_size:(row_idx+1)*_chunk_size] |
|
for col_idx in range(num_chunks): |
|
masks2_chunk = masks2[:, col_idx*_chunk_size:(col_idx+1)*_chunk_size] |
|
try: |
|
intersection = masks1_chunk * masks2_chunk |
|
temp_sum = masks1_chunk + masks2_chunk |
|
union = (temp_sum - intersection).sum(-1).sum(-1) |
|
intersection = intersection.sum(-1).sum(-1) |
|
except torch.cuda.OutOfMemoryError: |
|
return False, None, None |
|
col_chunks_intersection.append(intersection) |
|
col_chunks_union.append(union) |
|
row_chunks_intersection.append(torch.cat(col_chunks_intersection, dim=1)) |
|
row_chunks_union.append(torch.cat(col_chunks_union, dim=1)) |
|
intersection = torch.cat(row_chunks_intersection, dim=0) |
|
union = torch.cat(row_chunks_union, dim=0) |
|
return True, intersection, union |
|
|
|
for c_size in [chunk_size, chunk_size//2, chunk_size//4]: |
|
is_ok, intersection, union = chunk_mask_iou(c_size) |
|
if not is_ok: |
|
continue |
|
        return intersection, union

    # Every chunk size hit CUDA OOM; fail loudly rather than returning None.
    raise torch.cuda.OutOfMemoryError("mask_iou: OOM even at the smallest chunk size")
|
|
|
def mask_iou_v2(masks1, masks2, chunk_size=50, chunk_mode=False): |
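    """Pairwise intersection and union between `masks1` (N, H, W) and `masks2`
    (M, H, W), with the same chunked CUDA-OOM retry scheme as mask_iou."""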
|
masks1 = masks1.unsqueeze(1).char() |
|
masks2 = masks2.unsqueeze(0).char() |
|
|
|
if not chunk_mode: |
|
intersection = (masks1 * masks2) |
|
union = (masks1 + masks2 - intersection).sum(-1).sum(-1) |
|
intersection = intersection.sum(-1).sum(-1) |
|
|
|
return intersection, union |
|
|
|
def chunk_mask_iou(_chunk_size=50): |
|
num_chunks1 = masks1.shape[0] // _chunk_size |
|
if masks1.shape[0] % _chunk_size > 0: |
|
num_chunks1 += 1 |
|
|
|
num_chunks2 = masks2.shape[1] // _chunk_size |
|
        if masks2.shape[1] % _chunk_size > 0:  # masks2 is chunked along dim 1
|
num_chunks2 += 1 |
|
|
|
row_chunks_intersection, row_chunks_union = [], [] |
|
for row_idx in range(num_chunks1): |
|
col_chunks_intersection, col_chunks_union = [], [] |
|
masks1_chunk = masks1[row_idx*_chunk_size:(row_idx+1)*_chunk_size] |
|
for col_idx in range(num_chunks2): |
|
masks2_chunk = masks2[:, col_idx*_chunk_size:(col_idx+1)*_chunk_size] |
|
try: |
|
intersection = masks1_chunk * masks2_chunk |
|
temp_sum = masks1_chunk + masks2_chunk |
|
union = (temp_sum - intersection).sum(-1).sum(-1) |
|
intersection = intersection.sum(-1).sum(-1) |
|
except torch.cuda.OutOfMemoryError: |
|
return False, None, None |
|
col_chunks_intersection.append(intersection) |
|
col_chunks_union.append(union) |
|
row_chunks_intersection.append(torch.cat(col_chunks_intersection, dim=1)) |
|
row_chunks_union.append(torch.cat(col_chunks_union, dim=1)) |
|
intersection = torch.cat(row_chunks_intersection, dim=0) |
|
union = torch.cat(row_chunks_union, dim=0) |
|
return True, intersection, union |
|
|
|
for c_size in [chunk_size, chunk_size//2, chunk_size//4]: |
|
is_ok, intersection, union = chunk_mask_iou(c_size) |
|
if not is_ok: |
|
continue |
|
return intersection, union |
|
|
|
    # Every chunk size hit CUDA OOM; fail loudly rather than returning (None, None).
    raise torch.cuda.OutOfMemoryError("mask_iou_v2: OOM even at the smallest chunk size")
|
|
|
|
|
def mask_area(masks, chunk_size=50, chunk_mode=False): |
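    """Per-mask pixel area, optionally computed in row chunks to bound memory."""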
|
if not chunk_mode: |
|
return masks.sum(-1).sum(-1) |
|
|
|
num_chunks = masks.shape[0] // chunk_size |
|
if masks.shape[0] % chunk_size > 0: |
|
num_chunks += 1 |
|
|
|
areas = [] |
|
for i in range(num_chunks): |
|
masks_i = masks[i*chunk_size:(i+1)*chunk_size] |
|
areas.append(masks_i.sum(-1).sum(-1)) |
|
return torch.cat(areas, dim=0) |
|
|
|
|
|
from detectron2.utils.visualizer import GenericMask |
|
import matplotlib.colors as mplc |
|
def draw_instance_predictions_cache(self, labels, np_masks, jittering: bool = True): |
|
""" |
|
Draw instance-level prediction results on an image. |
|
|
|
Args: |
|
predictions (Instances): the output of an instance detection/segmentation |
|
model. Following fields will be used to draw: |
|
"pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle"). |
|
jittering: if True, in color mode SEGMENTATION, randomly jitter the colors per class |
|
to distinguish instances from the same class |
|
|
|
Returns: |
|
output (VisImage): image object with visualizations. |
|
""" |
|
boxes = None |
|
scores = None |
|
classes = None |
|
keypoints = None |
|
|
|
masks = [GenericMask(x, self.output.height, self.output.width) for x in np_masks] |
|
|
|
|
|
if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"): |
|
colors = ( |
|
[self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes] |
|
if jittering |
|
else [ |
|
tuple(mplc.to_rgb([x / 255 for x in self.metadata.thing_colors[c]])) |
|
for c in classes |
|
] |
|
) |
|
|
|
alpha = 0.8 |
|
else: |
|
colors = None |
|
alpha = 0.5 |
|
|
|
self.overlay_instances( |
|
masks=masks, |
|
boxes=boxes, |
|
labels=labels, |
|
keypoints=keypoints, |
|
assigned_colors=colors, |
|
alpha=alpha, |
|
) |
|
return self.output |
|
|
|
|
|
|
|
|
|
|
|
def merge_sa1b_image(image_file, anno_file, save_path, generated_annos, visualize=False): |
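    """De-duplicate SAM/SA-1B masks for one image: sort by predicted IoU, drop
    near-duplicates (IoU > 0.8), then drop masks mostly covered (> 0.8) by another
    kept mask. Returns the surviving masks as a stacked tensor."""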
|
file_name = os.path.basename(image_file).split('.')[0] |
|
|
|
if anno_file is not None: |
|
with open(anno_file, 'r') as f: |
|
json_results = json.load(f) |
|
generated_annos = json_results["annotations"] |
|
assert generated_annos is not None, "Provide valid annotation file or generated_annos from sam automatic generator." |
|
|
|
_all_sam_masks, predicted_iou_scores = [], [] |
|
for object_anno in generated_annos: |
|
object_mask = object_anno["segmentation"] |
|
if isinstance(object_mask["counts"], list): |
|
object_mask = mask_util.frPyObjects(object_mask, object_mask["size"][0], object_mask["size"][1]) |
|
mask = mask_util.decode(object_mask) |
|
mask = mask.astype(np.uint8).squeeze() |
|
_all_sam_masks.append(torch.from_numpy(mask)) |
|
predicted_iou_scores.append(object_anno['predicted_iou']) |
|
|
|
|
|
sorted_idx = sorted(range(len(predicted_iou_scores)), key=lambda k: predicted_iou_scores[k], reverse=True) |
|
all_sam_masks = [] |
|
for idx in sorted_idx: |
|
all_sam_masks.append(_all_sam_masks[idx]) |
|
|
|
all_sam_masks = torch.stack(all_sam_masks) |
|
ori_height, ori_width = all_sam_masks.shape[-2:] |
|
downsampled_sam_masks = torch.nn.functional.interpolate(all_sam_masks[None].to(torch.float32), size=(ori_height//4, ori_width//4), mode="bilinear") |
|
downsampled_sam_masks = (downsampled_sam_masks[0] > 0.5).to(all_sam_masks.dtype).to("cuda") |
|
|
|
intersection, union = mask_iou(downsampled_sam_masks, chunk_size=100, chunk_mode=True) |
|
|
|
mask_iou_matrix = intersection / union |
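
    # Greedy de-duplication: walk masks in descending predicted-IoU order and drop
    # any later mask whose IoU with a kept mask exceeds 0.8.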
|
|
|
|
|
num_instances = len(mask_iou_matrix) |
|
keep = [True] * num_instances |
|
for ins_i in range(num_instances): |
|
if not keep[ins_i]: |
|
continue |
|
for ins_j in range(ins_i, num_instances): |
|
if ins_j == ins_i: |
|
continue |
|
if mask_iou_matrix[ins_i, ins_j] > 0.8: |
|
keep[ins_j] = False |
|
|
|
|
|
|
|
area = mask_area(downsampled_sam_masks, chunk_mode=True, chunk_size=100) |
|
roc = intersection / area[:, None] |
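
    # roc[i, j] = intersection(i, j) / area(i): drop mask_i when it is mostly
    # (> 0.8) covered by another kept mask.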
|
for ins_i in range(num_instances): |
|
if not keep[ins_i]: |
|
continue |
|
for ins_j in range(num_instances): |
|
if ins_i == ins_j: |
|
continue |
|
if not keep[ins_j]: |
|
continue |
|
if roc[ins_i, ins_j] > 0.8: |
|
keep[ins_i] = False |
|
break |
|
|
|
left_masks = [all_sam_masks[ins_i] for ins_i in range(len(keep)) if keep[ins_i]] |
|
left_tags = ['object' for _ in range(len(left_masks))] |
|
|
|
unique_tags = list(set(left_tags)) |
|
text_prompt = ','.join(unique_tags) |
|
metadata = MetadataCatalog.get("__unused_ape_" + text_prompt) |
|
metadata.thing_classes = unique_tags |
|
metadata.stuff_classes = unique_tags |
|
|
|
if not visualize: |
|
return torch.stack(left_masks) |
|
|
|
|
|
def run_on_image_v2(image_file, anno_file, save_path, ram_predictor, ape_predictor, sam_predictor, sam_auto_mask_generator, visualize=False): |
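    """Merge SAM automatic masks with APE masks for one image: reconcile the two sets
    by IoU/coverage, refine disagreements by re-prompting SAM with points sampled
    inside the masks, de-duplicate, then either visualize or dump RLE JSON."""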
|
if not os.path.exists(image_file): |
|
return None |
|
file_name = os.path.basename(image_file).split('.')[0] |
|
if (anno_file is None) or (not os.path.exists(anno_file)): |
|
image = cv2.imread(image_file) |
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) |
|
generated_annos = sam_auto_mask_generator.generate(image) |
|
|
|
sam_masks = merge_sa1b_image(image_file, None, save_path, generated_annos, visualize=False) |
|
else: |
|
sam_masks = merge_sa1b_image(image_file, anno_file, save_path, None, visualize=False) |
|
|
|
ape_masks, ape_tags = run_on_image(image_file, save_path, ram_predictor, ape_predictor, sam_predictor, visualize=False) |
|
if ape_masks is None: |
|
return None |
|
|
|
sam_image = cv2.imread(image_file) |
|
ori_height, ori_width = sam_image.shape[:2] |
|
sam_image = cv2.cvtColor(sam_image, cv2.COLOR_BGR2RGB) |
|
sam_predictor.set_image(sam_image) |
|
|
|
ori_height, ori_width = sam_masks.shape[-2:] |
|
downsampled_sam_masks = torch.nn.functional.interpolate(sam_masks[None].to(torch.float32), size=(ori_height//4, ori_width//4), mode="bilinear") |
|
downsampled_sam_masks = (downsampled_sam_masks[0] > 0.5).to(sam_masks.dtype).to("cuda") |
|
|
|
downsampled_ape_masks = torch.nn.functional.interpolate(ape_masks[None].to(torch.float32), size=(ori_height//4, ori_width//4), mode="bilinear") |
|
downsampled_ape_masks = (downsampled_ape_masks[0] > 0.5).to(ape_masks.dtype).to("cuda") |
|
|
|
sam_ape_masks_intersection, sam_ape_masks_union = mask_iou_v2(downsampled_sam_masks, downsampled_ape_masks, chunk_size=100, chunk_mode=True) |
|
|
|
sam_ape_masks_iou = sam_ape_masks_intersection / sam_ape_masks_union |
|
|
|
sam_area = mask_area(downsampled_sam_masks, chunk_mode=True, chunk_size=100) |
|
sam_masks_roc = sam_ape_masks_intersection / sam_area[:, None] |
|
|
|
sam_boxes = torchvision.ops.masks_to_boxes(sam_masks) |
|
ape_boxes = torchvision.ops.masks_to_boxes(ape_masks) |
|
|
|
first_round_masks = [] |
|
iou_target_indices = torch.argmax(sam_ape_masks_iou, dim=1) |
|
roc_target_indices = torch.argmax(sam_masks_roc, dim=1) |
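
    # For each SAM mask: keep it when some APE mask agrees (IoU > 0.8); when it is
    # mostly covered by an APE mask (coverage > 0.8), re-prompt SAM with sampled
    # points and keep the proposal best matching APE if confident; otherwise
    # re-prompt SAM and keep the largest confident proposal, falling back to the
    # original mask in every branch.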
|
for sam_idx in range(downsampled_sam_masks.shape[0]): |
|
iou_tgt_idx = iou_target_indices[sam_idx] |
|
roc_tgt_idx = roc_target_indices[sam_idx] |
|
|
|
if sam_ape_masks_iou[sam_idx, iou_tgt_idx] > 0.8: |
|
first_round_masks.append(sam_masks[sam_idx]) |
|
elif sam_masks_roc[sam_idx, roc_tgt_idx] > 0.8: |
|
|
|
box_x1, box_y1, box_x2, box_y2 = sam_boxes[sam_idx] |
|
box_w = box_x2 - box_x1 |
|
box_h = box_y2 - box_y1 |
|
ret_points = sample_points([box_x1, box_y1, box_w, box_h], sam_masks[sam_idx], min_points=1, max_points=3) |
|
|
|
if len(ret_points) == 0 : |
|
first_round_masks.append(sam_masks[sam_idx]) |
|
else: |
|
point_labels = [1 for _ in range(len(ret_points))] |
|
temp_masks, scores, _ = sam_predictor.predict( |
|
point_coords=np.array(ret_points), |
|
point_labels=np.array(point_labels), |
|
multimask_output=True, |
|
) |
|
|
|
temp_masks = torch.from_numpy(temp_masks) |
|
downsampled_temp_masks = torch.nn.functional.interpolate(temp_masks[None].to(torch.float32), size=(ori_height//4, ori_width//4), mode="bilinear") |
|
downsampled_temp_masks = (downsampled_temp_masks[0] > 0.5).to(temp_masks.dtype).to("cuda") |
|
downsampled_ape_mask = downsampled_ape_masks[roc_tgt_idx][None] |
|
ape_temp_masks_intersection, ape_temp_masks_union = mask_iou_v2(downsampled_ape_mask, downsampled_temp_masks) |
|
ape_temp_masks_iou = ape_temp_masks_intersection / ape_temp_masks_union |
|
iou_temp_indices = torch.argmax(ape_temp_masks_iou, dim=1) |
|
iou_temp_idx = iou_temp_indices[0] |
|
if ape_temp_masks_iou[0, iou_temp_idx] > 0.8 and scores[iou_temp_idx] > 0.9: |
|
first_round_masks.append(temp_masks[iou_temp_idx]) |
|
else: |
|
first_round_masks.append(sam_masks[sam_idx]) |
|
else: |
|
|
|
box_x1, box_y1, box_x2, box_y2 = sam_boxes[sam_idx] |
|
box_w = box_x2 - box_x1 |
|
box_h = box_y2 - box_y1 |
|
ret_points = sample_points([box_x1, box_y1, box_w, box_h], sam_masks[sam_idx], min_points=1, max_points=3) |
|
|
|
if len(ret_points) == 0: |
|
first_round_masks.append(sam_masks[sam_idx]) |
|
else: |
|
point_labels = [1 for _ in range(len(ret_points))] |
|
temp_masks, scores, _ = sam_predictor.predict( |
|
point_coords=np.array(ret_points), |
|
point_labels=np.array(point_labels), |
|
multimask_output=True, |
|
) |
|
|
|
temp_masks = torch.from_numpy(temp_masks) |
|
temp_masks_area = temp_masks.sum(-1).sum(-1) |
|
tgt_idx = torch.argmax(temp_masks_area) |
|
if scores[tgt_idx] > 0.9: |
|
first_round_masks.append(temp_masks[tgt_idx]) |
|
else: |
|
first_round_masks.append(sam_masks[sam_idx]) |
|
|
|
|
|
ape_sam_masks_intersection, ape_sam_masks_union = sam_ape_masks_intersection.transpose(0, 1), sam_ape_masks_union.transpose(0, 1) |
|
|
|
ape_area = mask_area(downsampled_ape_masks, chunk_mode=True, chunk_size=100) |
|
ape_masks_roc = ape_sam_masks_intersection / ape_area[:, None] |
|
roc_target_indices = torch.argmax(ape_masks_roc, dim=1) |
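
    # APE-side pass: consider only APE masks not well covered by any SAM mask
    # (coverage < 0.2). If the overlap is low in both directions, re-segment from
    # sampled points (keeping the APE mask as fallback); if the best-matching SAM
    # mask lies mostly inside the APE mask, keep the APE mask only when SAM can
    # reproduce it (IoU > 0.8) from a single-point prompt.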
|
for ape_idx in range(ape_masks.shape[0]): |
|
roc_tgt_idx = roc_target_indices[ape_idx] |
|
if ape_masks_roc[ape_idx, roc_tgt_idx] < 0.2: |
|
if sam_masks_roc[roc_tgt_idx, ape_idx] < 0.2: |
|
box_x1, box_y1, box_x2, box_y2 = ape_boxes[ape_idx] |
|
box_w = box_x2 - box_x1 |
|
box_h = box_y2 - box_y1 |
|
ret_points = sample_points([box_x1, box_y1, box_w, box_h], ape_masks[ape_idx], min_points=3, max_points=16) |
|
|
|
if len(ret_points) == 0: |
|
first_round_masks.append(ape_masks[ape_idx]) |
|
else: |
|
point_labels = [1 for _ in range(len(ret_points))] |
|
temp_masks, scores, _ = sam_predictor.predict( |
|
point_coords=np.array(ret_points), |
|
point_labels=np.array(point_labels), |
|
multimask_output=False, |
|
) |
|
temp_masks = torch.from_numpy(temp_masks) |
|
if scores[0] > 0.9: |
|
first_round_masks.append(temp_masks[0]) |
|
else: |
|
first_round_masks.append(ape_masks[ape_idx]) |
|
else: |
|
|
|
box_x1, box_y1, box_x2, box_y2 = ape_boxes[ape_idx] |
|
box_w = box_x2 - box_x1 |
|
box_h = box_y2 - box_y1 |
|
ret_points = sample_points([box_x1, box_y1, box_w, box_h], ape_masks[ape_idx], min_points=3, max_points=8) |
|
for point in ret_points: |
|
temp_masks, scores, _ = sam_predictor.predict( |
|
point_coords=np.array([point]), |
|
point_labels=np.array([1]), |
|
multimask_output=True, |
|
) |
|
temp_masks = torch.from_numpy(temp_masks) |
|
downsampled_temp_masks = torch.nn.functional.interpolate(temp_masks[None].to(torch.float32), size=(ori_height//4, ori_width//4), mode="bilinear") |
|
downsampled_temp_masks = (downsampled_temp_masks[0] > 0.5).to(temp_masks.dtype).to("cuda") |
|
downsampled_ape_mask = downsampled_ape_masks[ape_idx][None] |
|
ape_temp_masks_intersection, ape_temp_masks_union = mask_iou_v2(downsampled_ape_mask, downsampled_temp_masks) |
|
ape_temp_masks_iou = ape_temp_masks_intersection / ape_temp_masks_union |
|
iou_temp_indices = torch.argmax(ape_temp_masks_iou, dim=1) |
|
iou_temp_idx = iou_temp_indices[0] |
|
if ape_temp_masks_iou[0, iou_temp_idx] > 0.8: |
|
first_round_masks.append(ape_masks[ape_idx]) |
|
|
|
|
|
|
|
    # Masks collected above mix torch.bool (SAM predictor outputs) and torch.uint8
    # (RLE-decoded masks); unify the dtype so torch.stack does not fail.
    first_round_masks = [m.to(torch.uint8) for m in first_round_masks]
    first_round_scores = mask_area(torch.stack(first_round_masks), chunk_mode=True, chunk_size=100)
|
|
|
sorted_idx = sorted(range(len(first_round_masks)), key=lambda k: first_round_scores[k], reverse=True) |
|
sorted_first_round_masks = [] |
|
for idx in sorted_idx: |
|
sorted_first_round_masks.append(first_round_masks[idx]) |
|
|
|
sorted_first_round_masks = torch.stack(sorted_first_round_masks) |
|
downsampled_first_round_masks = torch.nn.functional.interpolate(sorted_first_round_masks[None].to(torch.float32), size=(ori_height//4, ori_width//4), mode="bilinear") |
|
downsampled_first_round_masks = (downsampled_first_round_masks[0] > 0.5).to(sorted_first_round_masks.dtype) |
|
|
|
intersection, union = mask_iou(downsampled_first_round_masks, chunk_mode=True, chunk_size=100) |
|
mask_iou_matrix = intersection / union |
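
    # Second-round de-duplication over the merged SAM + APE masks, using the same
    # IoU (> 0.8) rule as in merge_sa1b_image and a stricter coverage rule (> 0.5).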
|
|
|
|
|
num_instances = len(mask_iou_matrix) |
|
keep = [True] * num_instances |
|
for ins_i in range(num_instances): |
|
if not keep[ins_i]: |
|
continue |
|
for ins_j in range(ins_i, num_instances): |
|
if ins_j == ins_i: |
|
continue |
|
if mask_iou_matrix[ins_i, ins_j] > 0.8: |
|
keep[ins_j] = False |
|
|
|
|
|
|
|
area = mask_area(downsampled_first_round_masks, chunk_mode=True, chunk_size=100) |
|
roc = intersection / area[:, None] |
|
for ins_i in range(num_instances): |
|
if not keep[ins_i]: |
|
continue |
|
for ins_j in range(num_instances): |
|
if ins_i == ins_j: |
|
continue |
|
if not keep[ins_j]: |
|
continue |
|
if roc[ins_i, ins_j] > 0.5: |
|
keep[ins_i] = False |
|
break |
|
|
|
left_masks = [sorted_first_round_masks[ins_i] for ins_i in range(len(keep)) if keep[ins_i]] |
|
if visualize: |
|
left_tags = ['object' for _ in range(len(left_masks))] |
|
|
|
unique_tags = list(set(left_tags)) |
|
text_prompt = ','.join(unique_tags) |
|
metadata = MetadataCatalog.get("__unused_ape_" + text_prompt) |
|
metadata.thing_classes = unique_tags |
|
metadata.stuff_classes = unique_tags |
|
|
|
result_masks = torch.stack(left_masks).cpu().numpy() |
|
|
|
input_image = read_image(image_file, format="BGR") |
|
visualizer = Visualizer(input_image[:, :, ::-1], metadata, instance_mode=ColorMode.IMAGE) |
|
visualizer.draw_instance_predictions = MethodType(draw_instance_predictions_cache, visualizer) |
|
vis_output = visualizer.draw_instance_predictions(labels=left_tags, np_masks=result_masks) |
|
output_image = vis_output.get_image() |
|
output_image = Image.fromarray(output_image) |
|
|
|
final_out_path = "./work_dirs/visualize_object_level" |
|
if not os.path.exists(final_out_path): |
|
os.makedirs(final_out_path) |
|
output_image.save(os.path.join(final_out_path, file_name+'.jpg')) |
|
else: |
|
result_masks = torch.stack(left_masks).cpu().numpy() |
|
|
|
save_json_results = [] |
|
for ins_i, mask in enumerate(result_masks): |
|
rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0] |
|
rle["counts"] = rle["counts"].decode("utf-8") |
|
save_json_results.append({ |
|
"ins_id": ins_i, |
|
"segmentation": rle, |
|
}) |
|
|
|
with open(os.path.join(save_path, file_name+'.json'), 'w') as f: |
|
json.dump(save_json_results, f) |
|
|
|
|
|
|
|
|
|
|
|
def run_on_image(image_file, save_path, ram_predictor, ape_predictor, sam_predictor, visualize=False): |
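    """Tag the image with RAM, run APE instance segmentation on those tags, refine
    each APE mask by point-prompting SAM, de-duplicate, then re-examine every
    surviving instance on an enlarged square crop, splitting it into sub-masks when
    APE finds several well-contained parts."""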
|
res = ram_predictor.run_on_image(image_file_path=image_file, dynamic_resolution=True) |
|
tag_list = [] |
|
for tag_string in res[0]: |
|
tags = tag_string.split(' | ') |
|
tag_list += tags |
|
tags = list(set(tag_list)) |
|
text_prompt = ','.join(tags) |
|
|
|
output_image, json_results = ape_predictor.run_on_image( |
|
image_file, |
|
input_text=text_prompt, |
|
visualize=True, |
|
score_threhold=0.1, |
|
output_type=["instance segmentation"], |
|
) |
|
|
|
if visualize: |
|
file_name = os.path.basename(image_file).split('.')[0] |
|
raw_ape_out_path = os.path.join(save_path, 'raw_ape_out_0116') |
|
if not os.path.exists(raw_ape_out_path): |
|
os.makedirs(raw_ape_out_path) |
|
output_image.save(os.path.join(raw_ape_out_path, file_name+'.jpg')) |
|
|
|
|
|
|
|
sam_image = cv2.imread(image_file) |
|
ori_height, ori_width = sam_image.shape[:2] |
|
sam_image = cv2.cvtColor(sam_image, cv2.COLOR_BGR2RGB) |
|
sam_predictor.set_image(sam_image) |
|
|
|
new_masks_from_sam = [] |
|
    corresponding_tags = []

    corresponding_scores = []
|
for idx, item in enumerate(json_results): |
|
object_mask = item["segmentation"] |
|
if isinstance(object_mask["counts"], list): |
|
object_mask = mask_util.frPyObjects(object_mask, object_mask["size"][0], object_mask["size"][1]) |
|
mask = mask_util.decode(object_mask) |
|
mask = mask.astype(np.uint8).squeeze() |
|
|
|
box = item["bbox"] |
|
|
|
ret_points = sample_points(box, mask) |
|
|
|
if len(ret_points) == 0: |
|
continue |
|
|
|
mask_h, mask_w = object_mask["size"] |
|
input_point, input_label = [], [] |
|
for point in ret_points: |
|
_x = point[0] / mask_w * ori_width |
|
_y = point[1] / mask_h * ori_height |
|
input_point.append([int(_x), int(_y)]) |
|
input_label.append(1) |
|
|
|
masks, scores, logits = sam_predictor.predict( |
|
point_coords=np.array(input_point), |
|
point_labels=np.array(input_label), |
|
multimask_output=False |
|
) |
|
|
|
new_masks_from_sam.append(torch.from_numpy(masks)) |
|
        corresponding_tags.append(item["category_name"])

        corresponding_scores.append(item["score"])
|
if len(new_masks_from_sam) == 0: |
|
return None, None |
|
new_masks_from_sam = torch.cat(new_masks_from_sam) |
|
downsampled_new_masks_from_sam = torch.nn.functional.interpolate(new_masks_from_sam[None].to(torch.float32), size=(ori_height//4, ori_width//4), mode="bilinear") |
|
downsampled_new_masks_from_sam = (downsampled_new_masks_from_sam[0] > 0.5).to(new_masks_from_sam.dtype).to("cuda") |
|
|
|
intersection, union = mask_iou(downsampled_new_masks_from_sam, chunk_mode=True, chunk_size=100) |
|
mask_iou_matrix = intersection / union |
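
    # De-duplicate the SAM-refined masks: drop near-duplicates (IoU > 0.8), then
    # masks mostly covered (> 0.8) by another kept mask.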
|
|
|
|
|
num_instances = len(mask_iou_matrix) |
|
keep = [True] * num_instances |
|
for ins_i in range(num_instances): |
|
if not keep[ins_i]: |
|
continue |
|
for ins_j in range(ins_i, num_instances): |
|
if ins_j == ins_i: |
|
continue |
|
if mask_iou_matrix[ins_i, ins_j] > 0.8: |
|
keep[ins_j] = False |
|
|
|
|
|
|
|
|
|
area = mask_area(downsampled_new_masks_from_sam, chunk_mode=True, chunk_size=100) |
|
roc = intersection / area[:, None] |
|
for ins_i in range(num_instances): |
|
if not keep[ins_i]: |
|
continue |
|
for ins_j in range(num_instances): |
|
if ins_i == ins_j: |
|
continue |
|
if not keep[ins_j]: |
|
continue |
|
if roc[ins_i, ins_j] > 0.8: |
|
keep[ins_i] = False |
|
break |
|
|
|
left_masks = [new_masks_from_sam[ins_i] for ins_i in range(len(keep)) if keep[ins_i]] |
|
left_masks = torch.stack(left_masks) |
|
left_boxes = torchvision.ops.masks_to_boxes(left_masks) |
|
    left_tags = [corresponding_tags[ins_i] for ins_i in range(len(keep)) if keep[ins_i]]
|
|
|
|
|
result_mask_list = [] |
|
result_tag_list = [] |
|
ori_image = Image.open(image_file) |
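
    # For each kept instance, crop a loose square patch (box enlarged by 25% per
    # side, padded to at least 256 px, clamped to the image) and re-run RAM + APE
    # on the crop to look for finer-grained tags and sub-masks.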
|
for ins_i, ins_box in enumerate(left_boxes): |
|
ins_box = ins_box.numpy().tolist() |
|
box_w = ins_box[2] - ins_box[0] |
|
box_h = ins_box[3] - ins_box[1] |
|
loose_box_x0 = int(ins_box[0] - box_w // 4) |
|
loose_box_y0 = int(ins_box[1] - box_h // 4) |
|
loose_box_x1 = int(ins_box[2] + box_w // 4) |
|
loose_box_y1 = int(ins_box[3] + box_h // 4) |
|
loose_box_x0 = loose_box_x0 if loose_box_x0 > 0 else 0 |
|
loose_box_y0 = loose_box_y0 if loose_box_y0 > 0 else 0 |
|
loose_box_x1 = loose_box_x1 if loose_box_x1 < ori_width else ori_width |
|
loose_box_y1 = loose_box_y1 if loose_box_y1 < ori_height else ori_height |
|
|
|
loose_box_w = loose_box_x1 - loose_box_x0 |
|
loose_box_h = loose_box_y1 - loose_box_y0 |
|
assert loose_box_w >= box_w and loose_box_h >= box_h |
|
|
|
if loose_box_w < 256: |
|
padded_length_w = 256 - loose_box_w |
|
left_padded = padded_length_w // 2 |
|
right_padded = padded_length_w - left_padded |
|
if loose_box_x0 - left_padded < 0: |
|
right_padded = right_padded + left_padded - loose_box_x0 |
|
left_padded = loose_box_x0 |
|
if loose_box_x1 + right_padded > ori_width: |
|
left_padded = left_padded + loose_box_x1 + right_padded - ori_width |
|
right_padded = ori_width - loose_box_x1 |
|
loose_box_x0 = int(loose_box_x0 - left_padded) |
|
loose_box_x1 = int(loose_box_x1 + right_padded) |
|
loose_box_x0 = loose_box_x0 if loose_box_x0 > 0 else 0 |
|
loose_box_x1 = loose_box_x1 if loose_box_x1 < ori_width else ori_width |
|
if loose_box_h < 256: |
|
padded_length_h = 256 - loose_box_h |
|
top_padded = padded_length_h // 2 |
|
bottom_padded = padded_length_h - top_padded |
|
if loose_box_y0 - top_padded < 0: |
|
bottom_padded = bottom_padded + top_padded - loose_box_y0 |
|
top_padded = loose_box_y0 |
|
if loose_box_y1 + bottom_padded > ori_height: |
|
top_padded = top_padded + loose_box_y1 + bottom_padded - ori_height |
|
bottom_padded = ori_height - loose_box_y1 |
|
loose_box_y0 = int(loose_box_y0 - top_padded) |
|
loose_box_y1 = int(loose_box_y1 + bottom_padded) |
|
loose_box_y0 = loose_box_y0 if loose_box_y0 > 0 else 0 |
|
loose_box_y1 = loose_box_y1 if loose_box_y1 < ori_height else ori_height |
|
|
|
loose_box_w = loose_box_x1 - loose_box_x0 |
|
loose_box_h = loose_box_y1 - loose_box_y0 |
|
if loose_box_w > loose_box_h: |
|
padded_length_h = loose_box_w - loose_box_h |
|
top_padded = padded_length_h // 2 |
|
bottom_padded = padded_length_h - top_padded |
|
if loose_box_y0 - top_padded < 0: |
|
bottom_padded = bottom_padded + top_padded - loose_box_y0 |
|
top_padded = loose_box_y0 |
|
if loose_box_y1 + bottom_padded > ori_height: |
|
top_padded = top_padded + loose_box_y1 + bottom_padded - ori_height |
|
bottom_padded = ori_height - loose_box_y1 |
|
loose_box_y0 = int(loose_box_y0 - top_padded) |
|
loose_box_y1 = int(loose_box_y1 + bottom_padded) |
|
loose_box_y0 = loose_box_y0 if loose_box_y0 > 0 else 0 |
|
loose_box_y1 = loose_box_y1 if loose_box_y1 < ori_height else ori_height |
|
elif loose_box_h > loose_box_w: |
|
padded_length_w = loose_box_h - loose_box_w |
|
left_padded = padded_length_w // 2 |
|
right_padded = padded_length_w - left_padded |
|
if loose_box_x0 - left_padded < 0: |
|
right_padded = right_padded + left_padded - loose_box_x0 |
|
left_padded = loose_box_x0 |
|
if loose_box_x1 + right_padded > ori_width: |
|
left_padded = left_padded + loose_box_x1 + right_padded - ori_width |
|
right_padded = ori_width - loose_box_x1 |
|
loose_box_x0 = int(loose_box_x0 - left_padded) |
|
loose_box_x1 = int(loose_box_x1 + right_padded) |
|
loose_box_x0 = loose_box_x0 if loose_box_x0 > 0 else 0 |
|
loose_box_x1 = loose_box_x1 if loose_box_x1 < ori_width else ori_width |
|
|
|
image_patch = ori_image.crop((loose_box_x0, loose_box_y0, loose_box_x1, loose_box_y1)) |
|
image_patch_w, image_patch_h = image_patch.size |
|
|
|
res = ram_predictor.run_on_image(image_file_path=image_patch, dynamic_resolution=False) |
|
tag_list = [] |
|
for tag_string in res[0]: |
|
tags = tag_string.split(' | ') |
|
tag_list += tags |
|
tags = list(set(tag_list)) |
|
text_prompt = ','.join(tags) |
|
|
|
if image_patch_w > image_patch_h: |
|
rescaled_image_patch_w = 1024 |
|
rescaled_image_patch_h = int(image_patch_h / image_patch_w * 1024) |
|
else: |
|
rescaled_image_patch_h = 1024 |
|
rescaled_image_patch_w = int(image_patch_w / image_patch_h * 1024) |
|
|
|
image_patch = image_patch.resize((rescaled_image_patch_w, rescaled_image_patch_h)) |
|
output_image, json_results = ape_predictor.run_on_image( |
|
image_patch, |
|
input_text=text_prompt, |
|
visualize=True, |
|
score_threhold=0.1, |
|
output_type=["instance segmentation"], |
|
) |
|
|
|
all_masks, all_tags = [], [] |
|
for idx, item in enumerate(json_results): |
|
object_mask = item["segmentation"] |
|
if isinstance(object_mask["counts"], list): |
|
object_mask = mask_util.frPyObjects(object_mask, object_mask["size"][0], object_mask["size"][1]) |
|
mask = mask_util.decode(object_mask) |
|
mask = torch.as_tensor(mask.astype(np.uint8)) |
|
all_masks.append(mask) |
|
all_tags.append(item['category_name']) |
|
|
|
|
|
if len(all_masks) == 0: |
|
result_mask_list.append(left_masks[ins_i]) |
|
result_tag_list.append(left_tags[ins_i]) |
|
continue |
|
|
|
all_masks = torch.stack(all_masks) |
|
|
|
        # Bilinear interpolation is not implemented for uint8 tensors; cast to
        # float32 as the other interpolate call sites in this file do.
        all_masks_ori_size = torch.nn.functional.interpolate(all_masks.unsqueeze(0).to(torch.float32),
                                                             size=(image_patch_h, image_patch_w), mode='bilinear')
|
all_masks_ori_size = all_masks_ori_size > 0.4 |
|
|
|
ori_mask_crop = left_masks[ins_i, loose_box_y0:loose_box_y1, loose_box_x0:loose_box_x1] |
|
|
|
|
|
|
|
|
|
masks1 = ori_mask_crop[None, None, :, :].char().to('cuda') |
|
masks2 = all_masks_ori_size.char().to('cuda') |
|
intersection = (masks1 * masks2) |
|
union = (masks1 + masks2 - intersection).sum(-1).sum(-1) |
|
intersection = intersection.sum(-1).sum(-1) |
|
area = masks2.sum(-1).sum(-1) |
|
|
|
masks_iou = intersection / union |
|
target_idx = torch.argmax(masks_iou, dim=1) |
|
|
|
if masks_iou[0, target_idx] < 0.8: |
|
temp_result_mask_list = [] |
|
temp_result_tag_list = [] |
|
for ins_j, mask_j_iou in enumerate(masks_iou[0]): |
|
if mask_j_iou < 0.1: |
|
continue |
|
roc_j = intersection[0, ins_j] / area[0, ins_j] |
|
if roc_j < 0.8: |
|
continue |
|
result_mask = torch.zeros((ori_height, ori_width)).to(all_masks.dtype) |
|
result_mask[loose_box_y0:loose_box_y1, loose_box_x0:loose_box_x1] = all_masks_ori_size[0, ins_j] |
|
temp_result_mask_list.append(result_mask) |
|
temp_result_tag_list.append(all_tags[ins_j]) |
|
if len(temp_result_mask_list) > 1: |
|
result_mask_list.extend(temp_result_mask_list) |
|
result_tag_list.extend(temp_result_tag_list) |
|
else: |
|
result_mask_list.append(left_masks[ins_i]) |
|
result_tag_list.append(left_tags[ins_i]) |
|
else: |
|
result_mask = torch.zeros((ori_height, ori_width)).to(all_masks.dtype) |
|
result_mask[loose_box_y0:loose_box_y1, loose_box_x0:loose_box_x1] = all_masks_ori_size[0, target_idx.item()] |
|
result_mask_list.append(result_mask) |
|
            result_tag_list.append(all_tags[target_idx.item()])
|
|
|
unique_tags = list(set(result_tag_list)) |
|
text_prompt = ','.join(unique_tags) |
|
metadata = MetadataCatalog.get("__unused_ape_" + text_prompt) |
|
metadata.thing_classes = unique_tags |
|
metadata.stuff_classes = unique_tags |
|
|
|
    if not visualize:
        # result_mask_list mixes torch.bool (SAM predictor outputs) and torch.uint8
        # (RLE-decoded APE masks); unify the dtype so torch.stack does not fail.
        return torch.stack([m.to(torch.uint8) for m in result_mask_list]), result_tag_list
|
|
|
def main(node_id=0, local_rank=0, work_dir="./work_dirs/object_level"): |
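    """Process this rank's shard of images: read the task list, build the RAM, APE,
    and SAM predictors, and write object-level annotations under work_dir. Roughly
    30% of images are visualized instead of being dumped as JSON."""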
|
|
|
    global_rank_id = int(node_id) * 8 + int(local_rank)  # assumes 8 GPUs per node
|
task_file = f"./work_dirs/object_level_task/rank{global_rank_id}.json" |
|
if not os.path.exists(task_file): |
|
print(f"No task file:{task_file}") |
|
return None |
|
with open(task_file, 'r') as f: |
|
sam_images = json.load(f) |
|
|
|
ram_predictor = build_ram_predictor(override_ckpt_file="third_parts/recognize_anything/xinyu1205/recognize-anything-plus-model/ram_plus_swin_large_14m.pth") |
|
ape_predictor = build_ape_predictor(which_categories='COCO', |
|
override_ckpt_file="third_parts/APE/shenyunhang/APE/configs/LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO_GQA_PhraseCut_Flickr30k/ape_deta/ape_deta_vitl_eva02_clip_vlf_lsj1024_cp_16x4_1080k_mdl_20230829_162438/model_final.pth") |
|
|
|
sam = build_sam_vit_h("third_parts/zhouyik/zt_any_visual_prompt/sam_vit_h_4b8939.pth") |
|
sam.to(device="cuda") |
|
sam_predictor = SamPredictor(sam) |
|
|
|
sam_auto_mask_generator = SamAutomaticMaskGenerator(sam) |
|
|
|
timer = Timer() |
|
past_time = 0 |
|
total_images = len(sam_images) |
|
|
|
for idx, sam_image_file in enumerate(sam_images): |
|
image_name = os.path.basename(sam_image_file).split('.')[0] |
|
dir_name = os.path.dirname(sam_image_file) |
|
sam_anno_file = os.path.join(dir_name, image_name+".json") |
|
save_dir = os.path.join(work_dir, os.path.basename(dir_name)) |
|
|
|
if os.path.exists(os.path.join(save_dir, image_name+'.json')): |
|
continue |
|
|
|
if not os.path.exists(save_dir): |
|
os.makedirs(save_dir) |
|
        visualize = random.random() < 0.3  # visualize ~30% of images instead of saving JSON
|
run_on_image_v2(sam_image_file, sam_anno_file, save_dir, |
|
ram_predictor, ape_predictor, sam_predictor, sam_auto_mask_generator, visualize=visualize) |
|
|
|
consume_time = "%.2f" % (timer.seconds() - past_time) |
|
past_time = timer.seconds() |
|
|
|
print(f"RANK#{local_rank}: {idx+1}/{total_images}, comsume {consume_time} seconds.") |
|
|
|
|
|
if __name__ == "__main__": |
|
work_dir, local_rank, node_id = sys.argv[1:] |
|
main(node_id=node_id, local_rank=local_rank, work_dir=work_dir) |
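
# Hypothetical invocation (the script name is illustrative):
#   python object_level_annotation.py <work_dir> <local_rank> <node_id>
# e.g. python object_level_annotation.py ./work_dirs/object_level 0 0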
|
|