File size: 2,432 Bytes
9a393e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Numpy BoxMaskList classes and functions."""

import numpy as np
from object_detection.utils import np_box_list


class BoxMaskList(np_box_list.BoxList):
  """Convenience wrapper for BoxList with masks.

  BoxMaskList extends np_box_list.BoxList to also hold one mask per box.
  Its constructor receives both boxes and masks and stores the masks in the
  'masks' field alongside the box coordinates. Note that each mask covers the
  full image, not just the region inside its box.
  """

  def __init__(self, box_data, mask_data):
    """Constructs a box collection with associated masks.

    Args:
      box_data: a numpy array of shape [N, 4] representing box coordinates.
        It is validated by the np_box_list.BoxList constructor.
      mask_data: a numpy uint8 array of shape [N, height, width] representing
        masks whose values are in {0, 1}. The masks correspond to the full
        image, so height and width are expected to equal the image height and
        width.

    Raises:
      ValueError: if bbox data is not a numpy array.
      ValueError: if invalid dimensions for bbox data.
      ValueError: if mask data is not a numpy array.
      ValueError: if mask data is not rank 3.
      ValueError: if mask data dtype is not uint8.
      ValueError: if the number of masks differs from the number of boxes.
    """
    # Validate and store the boxes first; the mask checks below rely on
    # box_data having passed the base-class validation.
    super(BoxMaskList, self).__init__(box_data)
    if not isinstance(mask_data, np.ndarray):
      raise ValueError('Mask data must be a numpy array.')
    if mask_data.ndim != 3:
      raise ValueError('Invalid dimensions for mask data.')
    # Only the dtype is enforced here; mask values are NOT checked to be in
    # {0, 1} -- callers are responsible for that.
    if mask_data.dtype != np.uint8:
      raise ValueError('Invalid data type for mask data: uint8 is required.')
    # Exactly one mask is required per box.
    if mask_data.shape[0] != box_data.shape[0]:
      raise ValueError('There should be the same number of boxes and masks.')
    self.data['masks'] = mask_data

  def get_masks(self):
    """Convenience function for accessing masks.

    Returns:
      the array stored in the 'masks' field. As validated by the constructor,
      this is a numpy uint8 array of shape [N, height, width] representing
      masks.
    """
    return self.get_field('masks')