# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Generate visual prompts.""" | |
import collections

import numpy as np
import numpy.random as npr


class VisualPrompter(object):
    """Generate point and box prompts from ground-truth boxes and masks.

    Prompts are accumulated per object in two buffers allocated by
    :meth:`reset`:

    * ``coords``: float32 array of shape ``(num, max_points + 1, 2)`` holding
      ``(x, y)`` pairs, initialised to ``-0.5``.
    * ``labels``: int64 array of shape ``(num, max_points + 1)``, initialised
      to ``padding_index``.

    Label values written by this class:

    * the ``gt_masks`` value at the sampled pixel, for points drawn from an
      error mask;
    * ``1`` for points drawn from the ground-truth mask;
    * ``2`` / ``3`` for the ``(x1, y1)`` / ``(x2, y2)`` corners of a box;
    * ``padding_index`` for unused slots.

    NOTE(review): ``stage_count`` is only ever reset to zero inside this
    class; advancing it between ``get_prompts`` calls appears to be the
    caller's responsibility -- confirm against the training loop.
    """

    def __init__(self, image_size=1024, max_points=9, num_experts=4, padding_index=4):
        """Create a prompter.

        Args:
            image_size: Target coordinate space. A tuple/list is used as-is
                (element 0 scales y, element 1 scales x); any other value is
                duplicated into a two-element list.
            max_points: The prompt buffers hold ``max_points + 1`` slots per
                object.
            num_experts: Length of the ``point_weight`` vector, which is
                ``[1000, 0, ..., 0]``.
            padding_index: Label stored in unused prompt slots.
        """
        super().__init__()
        self.num_stages = 2
        self.max_points = max_points
        self.point_weight = [1000] + [0] * (num_experts - 1)
        if isinstance(image_size, (tuple, list)):
            self.image_size = image_size
        else:
            self.image_size = [image_size] * 2
        self.padding_index = padding_index
        # Number of prompt slots already filled, keyed by object index.
        self.coord_count = collections.defaultdict(int)
        # Buffers are allocated lazily by ``reset``.
        self.coords = self.labels = self.boxes_turn = None
        self.stage_count = 0
        # Probability that an object receives a box prompt at stage 0.
        self.box_prob = 0.5

    def is_last_stage(self):
        """Return True when ``stage_count`` equals ``num_stages - 1``."""
        return self.stage_count == self.num_stages - 1

    def add_point(self, index, gt_masks, error_masks=None, num=1):
        """Sample ``num`` points for object ``index`` and append them.

        When ``error_masks`` is given, pixels are drawn from
        ``error_masks[index]`` and each point is labelled with the
        ``gt_masks`` value at that pixel. If no error mask is given, or it has
        no non-zero pixel, pixels are drawn from ``gt_masks[index]`` and
        labelled ``1``. If that mask is empty as well, the slots keep the
        ``-0.5`` coordinate and the padding label.

        Args:
            index: Object index into the masks and the prompt buffers.
            gt_masks: Ground-truth masks, indexed as ``[index, y, x]``.
            error_masks: Optional masks with the same indexing.
            num: Number of points to add.

        NOTE(review): there is no capacity check. If ``coord_count[index] +
        num`` exceeds the buffer length NumPy truncates the slice, so the
        assignment either fails to broadcast or (for ``num == 1``) writes
        nothing.
        """

        def sample(mask):
            """Draw ``num`` non-zero pixels of ``mask`` as ``(xs, ys)``."""
            ys, xs = np.nonzero(mask)
            count = ys.shape[0]
            if count > 0:
                # Sample with replacement only when there are too few pixels.
                picked = npr.choice(count, size=(num,), replace=num > count)
                return xs[picked], ys[picked]
            # Sentinel for "nothing to sample"; callers test ``ys[0] >= 0``.
            return [-0.5] * num, [-0.5] * num

        labels = [self.padding_index] * num
        if error_masks is not None:  # FP or FN point.
            xs, ys = sample(error_masks[index])
            if ys[0] >= 0:
                labels = gt_masks[index, ys, xs]
        if labels[0] == self.padding_index:  # GT point.
            xs, ys = sample(gt_masks[index])
            if ys[0] >= 0:
                labels = [1] * num
        # Map mask pixel indices to image coordinates through pixel centres.
        # The -0.5 sentinel maps onto itself: (-0.5 + 0.5) * scale - 0.5.
        scale_x = self.image_size[1] / gt_masks.shape[2]
        scale_y = self.image_size[0] / gt_masks.shape[1]
        xs = (np.array(xs, "float32") + 0.5) * scale_x - 0.5
        ys = (np.array(ys, "float32") + 0.5) * scale_y - 0.5
        start = self.coord_count[index]
        slots = slice(start, start + num)
        self.coords[index, slots] = np.vstack([xs, ys]).T
        self.labels[index, slots] = labels
        self.coord_count[index] += num

    def add_box(self, index, gt_boxes):
        """Append a jittered copy of ``gt_boxes[index]`` as two corner prompts.

        Each side gets Gaussian noise with a standard deviation of 10% of the
        box extent along that axis, clipped to ``[-20, 20]``. Only the outward
        part of the noise is applied, so the box can grow but never shrink.

        Args:
            index: Object index into ``gt_boxes`` and the prompt buffers.
            gt_boxes: Array whose first four columns are ``x1, y1, x2, y2``.

        NOTE(review): ``numpy.random.normal`` rejects a negative scale, so an
        inverted box (``x2 < x1`` or ``y2 < y1``) raises ``ValueError``.
        """
        x1, y1, x2, y2 = gt_boxes[index, :4]
        dx1, dx2 = np.clip(npr.normal(0.0, 0.1 * (x2 - x1), (2,)), -20, 20)
        dy1, dy2 = np.clip(npr.normal(0.0, 0.1 * (y2 - y1), (2,)), -20, 20)
        x1, y1 = x1 + np.minimum(dx1, 0), y1 + np.minimum(dy1, 0)
        x2, y2 = x2 + np.maximum(dx2, 0), y2 + np.maximum(dy2, 0)
        start = self.coord_count[index]
        self.coords[index, start] = (x1, y1)
        self.coords[index, start + 1] = (x2, y2)
        self.labels[index, start] = 2
        self.labels[index, start + 1] = 3
        self.coord_count[index] += 2

    def reset(self, num):
        """Start a new prompt sequence for ``num`` objects.

        Zeroes ``stage_count`` and the per-object slot counters, reallocates
        the coordinate/label buffers, and redraws which objects take a box
        prompt at stage 0 (each with probability ``box_prob``).
        """
        self.stage_count = 0
        self.coord_count.clear()
        self.coords = np.full((num, self.max_points + 1, 2), -0.5, "float32")
        self.labels = np.full((num, self.max_points + 1), self.padding_index, "int64")
        self.boxes_turn = npr.rand(num) < self.box_prob

    def get_prompts(self, gt_boxes, gt_masks=None, masks=None):
        """Add the prompts for the current stage and return all prompts so far.

        At stage 0 the state is reset first. Objects whose ``boxes_turn`` flag
        is set receive a box at stage 0; every other case receives points.
        When ``gt_masks`` is None only boxes can be produced.

        Args:
            gt_boxes: Array with one row per object; the first four columns
                are ``x1, y1, x2, y2``.
            gt_masks: Optional ground-truth masks, indexed ``[index, y, x]``.
            masks: Optional predicted masks, reshaped to ``gt_masks.shape``.
                Pixels where they differ from ``gt_masks`` become candidate
                point locations. Ignored when ``gt_masks`` is None.

        Returns:
            A dict with ``"points"``: the ``(coords, labels)`` buffers cut to
            the first ``1 + stage_count + num_points`` slots, and
            ``"point_score"``: ``(boxes_turn[:, None] - 0.5) * point_weight``.
        """
        num = gt_boxes.shape[0]
        if self.stage_count == 0:
            self.reset(num)
        error_masks = None
        # Without GT masks every prompt is a box, so an error mask is unused;
        # skipping it avoids dereferencing ``gt_masks.shape`` on None.
        if masks is not None and gt_masks is not None:
            masks = masks.reshape(gt_masks.shape)
            # Pixels where prediction and GT disagree (FP or FN);
            # bit-identical to (masks | gt_masks) ^ (masks & gt_masks).
            error_masks = masks ^ gt_masks
        num_points = 1
        if self.stage_count > 0:
            # Upper bound is exclusive: 1 .. max_points - stage_count.
            num_points = npr.randint(1, self.max_points + 1 - self.stage_count)
        elif self.box_prob == 0:
            # Point-only mode: start with 2 .. max_points points.
            num_points = npr.randint(2, self.max_points + 1)
        for index in range(num):
            is_box = self.stage_count == 0 and self.boxes_turn[index]
            if gt_masks is None or is_box:
                self.add_box(index, gt_boxes)
            else:
                self.add_point(index, gt_masks, error_masks, num_points)
        num_slots = 1 + self.stage_count + num_points
        coords = self.coords[:, :num_slots]
        labels = self.labels[:, :num_slots]
        scores = (self.boxes_turn[:, None] - 0.5) * self.point_weight
        return {"points": (coords, labels), "point_score": scores}