from collections import defaultdict
from typing import Optional

from ..image_utils import load_image
from ..utils import (
    add_end_docstrings,
    is_torch_available,
    logging,
    requires_backends,
)
from .base import PIPELINE_INIT_ARGS, ChunkPipeline


if is_torch_available():
    import torch

    from ..models.auto.modeling_auto import MODEL_FOR_MASK_GENERATION_MAPPING_NAMES

logger = logging.get_logger(__name__)


@add_end_docstrings(PIPELINE_INIT_ARGS)
class MaskGenerationPipeline(ChunkPipeline):
    """
    Automatic mask generation for images using `SamForMaskGeneration`. Given an image, this pipeline predicts binary
    masks for it. It is a `ChunkPipeline` because the points can be separated into mini-batches in order to avoid OOM
    issues. Use the `points_per_batch` argument to control the number of points that will be processed at the same
    time. Default is `64`.

    The pipeline works in 3 steps:
        1. `preprocess`: A grid of 1024 evenly spaced points is generated along with bounding boxes and point
           labels. For more details on how the points and bounding boxes are created, check the
           `generate_crop_boxes` function of the image processor. The image is also preprocessed using the
           `image_processor`. This function yields minibatches of `points_per_batch` points.

        2. `forward`: feeds the outputs of `preprocess` to the model. The image embedding is computed only once,
           using `self.model.get_image_embeddings`, with gradients disabled and with the tensors and the model on
           the same device.

        3. `postprocess`: The most important part of the automatic mask generation happens here. Three steps
           are involved:
                - `image_processor.post_process_masks` (run on each minibatch): takes in the raw output masks,
                  resizes them according to the image size, and transforms them to binary masks.
                - `image_processor.filter_masks` (run on each minibatch): uses both `pred_iou_thresh` and the
                  stability scores. Also applies a variety of filters based on non-maximum suppression to remove
                  bad masks.
                - `image_processor.post_process_for_mask_generation`: applies NMS to the masks to keep only the
                  relevant ones.

    Arguments:
        model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
            [`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
            [`PreTrainedTokenizer`].
        feature_extractor ([`SequenceFeatureExtractor`]):
            The feature extractor that will be used by the pipeline to encode the input.
        points_per_batch (`int`, *optional*, defaults to 64):
            Sets the number of points run simultaneously by the model. Higher numbers may be faster but use more GPU
            memory.
        output_bboxes_mask (`bool`, *optional*, defaults to `False`):
            Whether or not to output the bounding box predictions.
        output_rle_mask (`bool`, *optional*, defaults to `False`):
            Whether or not to output the masks in `RLE` format.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> generator = pipeline(model="facebook/sam-vit-base", task="mask-generation")
    >>> outputs = generator(
    ...     "http://images.cocodataset.org/val2017/000000039769.jpg",
    ... )

    >>> outputs = generator(
    ...     "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", points_per_batch=128
    ... )
    ```
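
    The mini-batching of points described above can be sketched in plain Python. This is an illustration only, not
    the pipeline's own code: `chunk_points` is a hypothetical helper, and it assumes the final chunk may hold fewer
    than `points_per_batch` points.

    ```python
    # Hypothetical helper, for illustration only: split `n_points` prompt points into index ranges.
    def chunk_points(n_points, points_per_batch=64):
        # `None` is taken to mean "process every point in a single chunk".
        if points_per_batch is None:
            points_per_batch = n_points
        if points_per_batch <= 0:
            raise ValueError("points_per_batch must be >= 1 or None")
        for start in range(0, n_points, points_per_batch):
            stop = min(start + points_per_batch, n_points)
            # The third item flags the final chunk.
            yield start, stop, stop == n_points


    # A 32 x 32 grid gives 1024 points, processed here in 16 chunks of 64.
    chunks = list(chunk_points(1024, 64))
    print(len(chunks), chunks[-1])
    ```

    Larger values of `points_per_batch` mean fewer forward passes but more memory per pass.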

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This segmentation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"mask-generation"`.

    See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=mask-generation).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        requires_backends(self, "vision")
        requires_backends(self, "torch")

        if self.framework != "pt":
            raise ValueError(f"The {self.__class__} is only available in PyTorch.")

        self.check_model_type(MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)

    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        forward_params = {}
        # preprocess args
        if "points_per_batch" in kwargs:
            preprocess_kwargs["points_per_batch"] = kwargs["points_per_batch"]
        if "points_per_crop" in kwargs:
            preprocess_kwargs["points_per_crop"] = kwargs["points_per_crop"]
        if "crops_n_layers" in kwargs:
            preprocess_kwargs["crops_n_layers"] = kwargs["crops_n_layers"]
        if "crop_overlap_ratio" in kwargs:
            preprocess_kwargs["crop_overlap_ratio"] = kwargs["crop_overlap_ratio"]
        if "crop_n_points_downscale_factor" in kwargs:
            preprocess_kwargs["crop_n_points_downscale_factor"] = kwargs["crop_n_points_downscale_factor"]
        if "timeout" in kwargs:
            preprocess_kwargs["timeout"] = kwargs["timeout"]
        # forward args
        if "pred_iou_thresh" in kwargs:
            forward_params["pred_iou_thresh"] = kwargs["pred_iou_thresh"]
        if "stability_score_offset" in kwargs:
            forward_params["stability_score_offset"] = kwargs["stability_score_offset"]
        if "mask_threshold" in kwargs:
            forward_params["mask_threshold"] = kwargs["mask_threshold"]
        if "stability_score_thresh" in kwargs:
            forward_params["stability_score_thresh"] = kwargs["stability_score_thresh"]
        # postprocess args
        if "crops_nms_thresh" in kwargs:
            postprocess_kwargs["crops_nms_thresh"] = kwargs["crops_nms_thresh"]
        if "output_rle_mask" in kwargs:
            postprocess_kwargs["output_rle_mask"] = kwargs["output_rle_mask"]
        if "output_bboxes_mask" in kwargs:
            postprocess_kwargs["output_bboxes_mask"] = kwargs["output_bboxes_mask"]
        return preprocess_kwargs, forward_params, postprocess_kwargs

    def __call__(self, image, *args, num_workers=None, batch_size=None, **kwargs):
        """
        Generates binary segmentation masks.

        Args:
            image (`np.ndarray` or `bytes` or `str` or `dict`):
                Image or list of images.
            mask_threshold (`float`, *optional*, defaults to 0.0):
                Threshold to use when turning the predicted masks into binary values.
            pred_iou_thresh (`float`, *optional*, defaults to 0.88):
                A filtering threshold in `[0,1]` applied on the model's predicted mask quality.
            stability_score_thresh (`float`, *optional*, defaults to 0.95):
                A filtering threshold in `[0,1]`, using the stability of the mask under changes to the cutoff used to
                binarize the model's mask predictions.
            stability_score_offset (`int`, *optional*, defaults to 1):
                The amount to shift the cutoff when calculating the stability score.
            crops_nms_thresh (`float`, *optional*, defaults to 0.7):
                The box IoU cutoff used by non-maximal suppression to filter duplicate masks.
            crops_n_layers (`int`, *optional*, defaults to 0):
                If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of
                layers to run, where each layer has 2**i_layer number of image crops.
            crop_overlap_ratio (`float`, *optional*, defaults to `512 / 1500`):
                Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
                the image length. Later layers with more crops scale down this overlap.
            crop_n_points_downscale_factor (`int`, *optional*, defaults to `1`):
                The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
            timeout (`float`, *optional*, defaults to None):
                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
                the call may block forever.

        Return:
            `Dict`: A dictionary with the following keys:
                - **masks** -- The binary masks kept after filtering and non-maximum suppression.
                - **scores** -- The predicted IoU score of each returned mask.
                - **rle_mask** (*optional*) -- The masks in `RLE` format, returned when `output_rle_mask=True`.
                - **bounding_boxes** (*optional*) -- The bounding boxes of the masks, returned when
                  `output_bboxes_mask=True`.
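
        The effect of `stability_score_thresh` and `stability_score_offset` can be illustrated with a small sketch.
        It assumes the score is the number of pixels above `mask_threshold + stability_score_offset` divided by the
        number above `mask_threshold - stability_score_offset`. The actual computation lives in the image processor
        and may differ in details, and `stability_score` is a hypothetical helper, not part of this pipeline.

        ```python
        # Hypothetical helper, for illustration only (assumed definition, see the text above).
        def stability_score(mask_logits, mask_threshold=0.0, stability_score_offset=1.0):
            # Pixels that are still foreground with a stricter cutoff.
            strict = sum(1 for value in mask_logits if value > mask_threshold + stability_score_offset)
            # Pixels that are foreground with a looser cutoff.
            loose = sum(1 for value in mask_logits if value > mask_threshold - stability_score_offset)
            return strict / loose if loose else 0.0


        # Confident logits: shifting the cutoff changes nothing, so the score is 1.0.
        sharp = stability_score([5.0, 4.0, -4.0, -5.0])
        # Logits near the cutoff: half of the loose foreground flips, so the score is 0.5.
        fuzzy = stability_score([3.0, 2.5, 0.5, -0.5, -3.0])
        print(sharp, fuzzy)
        ```

        Under this assumed definition and the default `stability_score_thresh` of 0.95, the first mask would be
        kept and the second discarded.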

        """
        return super().__call__(image, *args, num_workers=num_workers, batch_size=batch_size, **kwargs)

    def preprocess(
        self,
        image,
        points_per_batch=64,
        crops_n_layers: int = 0,
        crop_overlap_ratio: float = 512 / 1500,
        points_per_crop: Optional[int] = 32,
        crop_n_points_downscale_factor: Optional[int] = 1,
        timeout: Optional[float] = None,
    ):
        image = load_image(image, timeout=timeout)
        target_size = self.image_processor.size["longest_edge"]
        crop_boxes, grid_points, cropped_images, input_labels = self.image_processor.generate_crop_boxes(
            image, target_size, crops_n_layers, crop_overlap_ratio, points_per_crop, crop_n_points_downscale_factor
        )
        model_inputs = self.image_processor(images=cropped_images, return_tensors="pt")

        with self.device_placement():
            if self.framework == "pt":
                inference_context = self.get_inference_context()
                with inference_context():
                    model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
                    image_embeddings = self.model.get_image_embeddings(model_inputs.pop("pixel_values"))
                    model_inputs["image_embeddings"] = image_embeddings

        n_points = grid_points.shape[1]
        points_per_batch = points_per_batch if points_per_batch is not None else n_points

        if points_per_batch <= 0:
            raise ValueError(
                "Cannot have points_per_batch<=0. Must be >=1 to return batched outputs. "
                "To return all points at once, set points_per_batch to None"
            )

        for i in range(0, n_points, points_per_batch):
            batched_points = grid_points[:, i : i + points_per_batch, :, :]
            labels = input_labels[:, i : i + points_per_batch]
            # True for the final chunk, even when n_points is not a multiple of points_per_batch
            is_last = i + points_per_batch >= n_points
            yield {
                "input_points": batched_points,
                "input_labels": labels,
                "input_boxes": crop_boxes,
                "is_last": is_last,
                **model_inputs,
            }

    def _forward(
        self,
        model_inputs,
        pred_iou_thresh=0.88,
        stability_score_thresh=0.95,
        mask_threshold=0,
        stability_score_offset=1,
    ):
        input_boxes = model_inputs.pop("input_boxes")
        is_last = model_inputs.pop("is_last")
        original_sizes = model_inputs.pop("original_sizes").tolist()
        reshaped_input_sizes = model_inputs.pop("reshaped_input_sizes").tolist()

        model_outputs = self.model(**model_inputs)

        # post processing happens here in order to avoid CPU GPU copies of ALL the masks
        low_resolution_masks = model_outputs["pred_masks"]
        masks = self.image_processor.post_process_masks(
            low_resolution_masks, original_sizes, reshaped_input_sizes, mask_threshold, binarize=False
        )
        iou_scores = model_outputs["iou_scores"]
        masks, iou_scores, boxes = self.image_processor.filter_masks(
            masks[0],
            iou_scores[0],
            original_sizes[0],
            input_boxes[0],
            pred_iou_thresh,
            stability_score_thresh,
            mask_threshold,
            stability_score_offset,
        )
        return {
            "masks": masks,
            "is_last": is_last,
            "boxes": boxes,
            "iou_scores": iou_scores,
        }

    def postprocess(
        self,
        model_outputs,
        output_rle_mask=False,
        output_bboxes_mask=False,
        crops_nms_thresh=0.7,
    ):
        all_scores = []
        all_masks = []
        all_boxes = []
        for model_output in model_outputs:
            all_scores.append(model_output.pop("iou_scores"))
            all_masks.extend(model_output.pop("masks"))
            all_boxes.append(model_output.pop("boxes"))

        all_scores = torch.cat(all_scores)
        all_boxes = torch.cat(all_boxes)
        output_masks, iou_scores, rle_mask, bounding_boxes = self.image_processor.post_process_for_mask_generation(
            all_masks, all_scores, all_boxes, crops_nms_thresh
        )

        extra = defaultdict(list)
        for output in model_outputs:
            for k, v in output.items():
                extra[k].append(v)

        optional = {}
        if output_rle_mask:
            optional["rle_mask"] = rle_mask

        if output_bboxes_mask:
            optional["bounding_boxes"] = bounding_boxes

        return {"masks": output_masks, "scores": iou_scores, **optional, **extra}