""" |
|
Processor class for SAM. |
|
""" |
|
from copy import deepcopy |
|
from typing import Optional, Union |
|
|
|
import numpy as np |
|
|
|
from transformers.processing_utils import ProcessorMixin |
|
from transformers.tokenization_utils_base import BatchEncoding |
|
from transformers.utils import TensorType, is_tf_available, is_torch_available |
|
|
|
|
|
if is_torch_available(): |
|
import torch |
|
|
|
if is_tf_available(): |
|
import tensorflow as tf |
|
|
|
|
|
class SamProcessor(ProcessorMixin): |
|
r""" |
|
Constructs a SAM processor which wraps a SAM image processor and an 2D points & Bounding boxes processor into a |
|
single processor. |
|
|
|
[`SamProcessor`] offers all the functionalities of [`SamImageProcessor`]. See the docstring of |
|
[`~SamImageProcessor.__call__`] for more information. |
|
|
|
Args: |
|
image_processor (`SamImageProcessor`): |
|
An instance of [`SamImageProcessor`]. The image processor is a required input. |
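
    Example (a minimal usage sketch; the checkpoint name and image URL are illustrative, not fixed by this class):

    ```python
    >>> from transformers import SamProcessor
    >>> from PIL import Image
    >>> import requests

    >>> processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> raw_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
    >>> # One 2D point prompt for the single image, given in (x, y) pixel coordinates.
    >>> inputs = processor(images=raw_image, input_points=[[[450, 600]]], return_tensors="pt")
    ```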
""" |
|
attributes = ["image_processor"] |
|
image_processor_class = "SamImageProcessor" |
|
|
|
def __init__(self, image_processor): |
|
super().__init__(image_processor) |
|
self.current_processor = self.image_processor |
|
self.point_pad_value = -10 |
|
self.target_size = self.image_processor.size["longest_edge"] |
|
|
|
def __call__( |
|
self, |
|
images=None, |
|
input_points=None, |
|
input_labels=None, |
|
input_boxes=None, |
|
return_tensors: Optional[Union[str, TensorType]] = None, |
|
**kwargs, |
|
) -> BatchEncoding: |
|
""" |
|
This method uses [`SamImageProcessor.__call__`] method to prepare image(s) for the model. It also prepares 2D |
|
points and bounding boxes for the model if they are provided. |
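
        The prompt inputs are nested Python lists batched over images. A minimal sketch of the expected nesting,
        with illustrative coordinates (`processor` and `raw_image` as in the class-level example):

        ```python
        >>> inputs = processor(
        ...     images=raw_image,
        ...     input_points=[[[450, 600], [500, 650]]],  # 1 image, one set of two (x, y) points
        ...     input_labels=[[1, 1]],  # one label per point (1 is typically a foreground point)
        ...     input_boxes=[[[425, 600, 700, 875]]],  # 1 image, one box as (x1, y1, x2, y2)
        ...     return_tensors="pt",
        ... )
        ```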
""" |
|
encoding_image_processor = self.process_images( |
|
images, |
|
return_tensors=return_tensors, |
|
**kwargs, |
|
) |
|
|
|
|
|
original_sizes = encoding_image_processor["original_sizes"] |
|
encoding_prompts_processor = self.process_prompts( |
|
original_sizes, |
|
input_points=input_points, |
|
input_labels=input_labels, |
|
input_boxes=input_boxes, |
|
return_tensors=return_tensors, |
|
) |
|
|
|
encoding_image_processor.update(encoding_prompts_processor) |
|
return encoding_image_processor |
|
|
|
def process_images(self, images, return_tensors, **kwargs): |
|
return self.image_processor( |
|
images, |
|
return_tensors=return_tensors, |
|
**kwargs, |
|
) |
|
|
|
def process_prompts( |
|
self, |
|
original_sizes, |
|
input_points=None, |
|
input_labels=None, |
|
input_boxes=None, |
|
return_tensors: Optional[Union[str, TensorType]] = None, |
|
): |
|
if hasattr(original_sizes, "numpy"): |
|
original_sizes = original_sizes.numpy() |
|
|
|
try: |
|
input_points, input_labels, input_boxes = self._check_and_preprocess_points( |
|
input_points=input_points, |
|
input_labels=input_labels, |
|
input_boxes=input_boxes, |
|
) |
|
except Exception as e: |
|
raise ValueError( |
|
f"Error when checking inputs: {e}\n" |
|
f"input_points: {input_points}\n" |
|
f"input_labels: {input_labels}\n" |
|
f"input_boxes: {input_boxes}\n" |
|
) |
|
|
|
encoding_prompts_processor = self._normalize_and_convert( |
|
original_sizes, |
|
input_points=input_points, |
|
input_labels=input_labels, |
|
input_boxes=input_boxes, |
|
return_tensors=return_tensors, |
|
) |
|
|
|
return encoding_prompts_processor |
|
|
|

    def _normalize_and_convert(
        self,
        original_sizes,
        input_points=None,
        input_labels=None,
        input_boxes=None,
        return_tensors="pt",
    ):
        return_dict = {}
        if input_points is not None:
            if len(original_sizes) != len(input_points):
                # A single original size is broadcast to every point set.
                input_points = [
                    self._normalize_coordinates(self.target_size, point, original_sizes[0]) for point in input_points
                ]
            else:
                input_points = [
                    self._normalize_coordinates(self.target_size, point, original_size)
                    for point, original_size in zip(input_points, original_sizes)
                ]

            # If the per-image point sets are ragged, pad them so they can be stacked into a single array.
            if not all(point.shape == input_points[0].shape for point in input_points):
                if input_labels is not None:
                    input_points, input_labels = self._pad_points_and_labels(input_points, input_labels)

            input_points = np.array(input_points)

        if input_labels is not None:
            input_labels = np.array(input_labels)

        if input_boxes is not None:
            if len(original_sizes) != len(input_boxes):
                input_boxes = [
                    self._normalize_coordinates(self.target_size, box, original_sizes[0], is_bounding_box=True)
                    for box in input_boxes
                ]
            else:
                input_boxes = [
                    self._normalize_coordinates(self.target_size, box, original_size, is_bounding_box=True)
                    for box, original_size in zip(input_boxes, original_sizes)
                ]
            input_boxes = np.array(input_boxes)

        if input_boxes is not None:
            if return_tensors == "pt":
                input_boxes = torch.from_numpy(input_boxes)
                # Boxes are returned as (batch_size, nb_boxes, 4); add the missing dimension if necessary.
                input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes
            elif return_tensors == "tf":
                input_boxes = tf.convert_to_tensor(input_boxes)
                input_boxes = tf.expand_dims(input_boxes, 1) if len(input_boxes.shape) != 3 else input_boxes
            return_dict.update({"input_boxes": input_boxes})
        if input_points is not None:
            if return_tensors == "pt":
                input_points = torch.from_numpy(input_points)
                # Points are returned as (batch_size, point_batch_size, nb_points_per_image, 2).
                input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points
            elif return_tensors == "tf":
                input_points = tf.convert_to_tensor(input_points)
                input_points = tf.expand_dims(input_points, 1) if len(input_points.shape) != 4 else input_points
            return_dict.update({"input_points": input_points})
        if input_labels is not None:
            if return_tensors == "pt":
                input_labels = torch.from_numpy(input_labels)
                # Labels are returned as (batch_size, point_batch_size, nb_points_per_image).
                input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels
            elif return_tensors == "tf":
                input_labels = tf.convert_to_tensor(input_labels)
                input_labels = tf.expand_dims(input_labels, 1) if len(input_labels.shape) != 3 else input_labels
            return_dict.update({"input_labels": input_labels})

        return return_dict

    def _pad_points_and_labels(self, input_points, input_labels):
        r"""
        The method pads the 2D points and labels to the maximum number of points in the batch.
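
        For example, given two point sets of shapes `(2, 2)` and `(1, 2)`, the second set is padded to `(2, 2)` with
        the sentinel coordinate `(-10, -10)` (`self.point_pad_value`), and its label array gets a matching `-10`
        appended so the padded points can be told apart from real prompts.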
""" |
|
expected_nb_points = max([point.shape[0] for point in input_points]) |
|
processed_input_points = [] |
|
for i, point in enumerate(input_points): |
|
if point.shape[0] != expected_nb_points: |
|
point = np.concatenate( |
|
[point, np.zeros((expected_nb_points - point.shape[0], 2)) + self.point_pad_value], axis=0 |
|
) |
|
input_labels[i] = np.append(input_labels[i], [self.point_pad_value]) |
|
processed_input_points.append(point) |
|
input_points = processed_input_points |
|
return input_points, input_labels |
|
|
|
def _normalize_coordinates( |
|
self, target_size: int, coords: np.ndarray, original_size, is_bounding_box=False |
|
) -> np.ndarray: |
|
""" |
|
Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format. |
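
        For example (illustrative numbers): with `target_size=1024` and `original_size=(512, 256)`, the longest edge
        is rescaled from 512 to 1024, so the new size is `(1024, 512)` and a point `(x, y) = (100, 200)` is mapped to
        `(100 * 512 / 256, 200 * 1024 / 512) = (200, 400)`.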
""" |
|
old_h, old_w = original_size |
|
new_h, new_w = self.image_processor._get_preprocess_shape(original_size, longest_edge=target_size) |
|
coords = deepcopy(coords).astype(float) |
|
|
|
if is_bounding_box: |
|
coords = coords.reshape(-1, 2, 2) |
|
|
|
coords[..., 0] = coords[..., 0] * (new_w / old_w) |
|
coords[..., 1] = coords[..., 1] * (new_h / old_h) |
|
|
|
if is_bounding_box: |
|
coords = coords.reshape(-1, 4) |
|
|
|
return coords |
|
|
|
def _check_and_preprocess_points( |
|
self, |
|
input_points=None, |
|
input_labels=None, |
|
input_boxes=None, |
|
): |
|
r""" |
|
Check and preprocesses the 2D points, labels and bounding boxes. It checks if the input is valid and if they |
|
are, it converts the coordinates of the points and bounding boxes. If a user passes directly a `torch.Tensor`, |
|
it is converted to a `numpy.ndarray` and then to a `list`. |
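
        For example, `input_points=[[[450, 600]]]` (one image with a single point) is accepted, whereas
        `input_points=[450, 600]` raises a `ValueError` because the per-image nesting is missing.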
""" |
|
if input_points is not None: |
|
if hasattr(input_points, "numpy"): |
|
input_points = input_points.numpy().tolist() |
|
|
|
if not isinstance(input_points, list) or not isinstance(input_points[0], list): |
|
raise ValueError("Input points must be a list of list of floating points.") |
|
input_points = [np.array(input_point) for input_point in input_points] |
|
else: |
|
input_points = None |
|
|
|
if input_labels is not None: |
|
if hasattr(input_labels, "numpy"): |
|
input_labels = input_labels.numpy().tolist() |
|
|
|
if not isinstance(input_labels, list) or not isinstance(input_labels[0], list): |
|
raise ValueError("Input labels must be a list of list integers.") |
|
input_labels = [np.array(label) for label in input_labels] |
|
else: |
|
input_labels = None |
|
|
|
if input_boxes is not None: |
|
if hasattr(input_boxes, "numpy"): |
|
input_boxes = input_boxes.numpy().tolist() |
|
|
|
if ( |
|
not isinstance(input_boxes, list) |
|
or not isinstance(input_boxes[0], list) |
|
or not isinstance(input_boxes[0][0], list) |
|
): |
|
raise ValueError("Input boxes must be a list of list of list of floating points.") |
|
input_boxes = [np.array(box).astype(np.float32) for box in input_boxes] |
|
else: |
|
input_boxes = None |
|
|
|
return input_points, input_labels, input_boxes |
|
|
|
@property |
|
def model_input_names(self): |
|
image_processor_input_names = self.image_processor.model_input_names |
|
return list(dict.fromkeys(image_processor_input_names)) |
|
|
|
def post_process_masks(self, *args, **kwargs): |
|
return self.image_processor.post_process_masks(*args, **kwargs) |
|
|