File size: 6,807 Bytes
681fa96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
# Copyright (c) Facebook, Inc. and its affiliates.
import itertools
import warnings
from typing import Any, Dict, List, Tuple, Union
import torch


class Instances:
    """
    This class represents a list of instances in an image.
    It stores the attributes of instances (e.g., boxes, masks, labels, scores) as "fields".
    All fields must have the same ``__len__`` which is the number of instances.

    All other (non-field) attributes of this class are considered private:
    they must start with '_' and are not modifiable by a user.

    Some basic usage:

    1. Set/get/check a field:

       .. code-block:: python

          instances.gt_boxes = Boxes(...)
          print(instances.pred_masks)  # a tensor of shape (N, H, W)
          print('gt_masks' in instances)

    2. ``len(instances)`` returns the number of instances
    3. Indexing: ``instances[indices]`` will apply the indexing on all the fields
       and returns a new :class:`Instances`.
       Typically, ``indices`` is an integer vector of indices,
       or a binary mask of length ``num_instances``

       .. code-block:: python

          category_3_detections = instances[instances.pred_classes == 3]
          confident_detections = instances[instances.scores > 0.9]
    """

    def __init__(self, image_size: Tuple[int, int], **kwargs: Any):
        """
        Args:
            image_size (height, width): the spatial size of the image.
            kwargs: fields to add to this `Instances`.
        """
        self._image_size = image_size
        # Maps field name -> per-instance data; all values share one length.
        self._fields: Dict[str, Any] = {}
        for k, v in kwargs.items():
            self.set(k, v)

    @property
    def image_size(self) -> Tuple[int, int]:
        """
        Returns:
            tuple: height, width
        """
        return self._image_size

    def __setattr__(self, name: str, val: Any) -> None:
        # Underscore-prefixed names are ordinary (private) attributes; every
        # other name is routed through `set` so it becomes a length-checked field.
        if name.startswith("_"):
            super().__setattr__(name, val)
        else:
            self.set(name, val)

    def __getattr__(self, name: str) -> Any:
        # Only invoked when normal attribute lookup fails. The explicit "_fields"
        # check prevents infinite recursion when `_fields` itself is missing
        # (e.g. on an object created without running __init__).
        if name == "_fields" or name not in self._fields:
            raise AttributeError("Cannot find field '{}' in the given Instances!".format(name))
        return self._fields[name]

    def set(self, name: str, value: Any) -> None:
        """
        Set the field named `name` to `value`.
        The length of `value` must be the number of instances,
        and must agree with other existing fields in this object.

        Raises:
            AssertionError: if the length of `value` differs from existing fields.
        """
        # Any warnings emitted by len() are recorded and discarded -- presumably
        # to silence tracer warnings when `value` is a traced tensor (TODO confirm).
        with warnings.catch_warnings(record=True):
            data_len = len(value)
        if len(self._fields):
            assert (
                len(self) == data_len
            ), "Adding a field of length {} to a Instances of length {}".format(data_len, len(self))
        self._fields[name] = value

    def has(self, name: str) -> bool:
        """
        Returns:
            bool: whether the field called `name` exists.
        """
        return name in self._fields

    def __contains__(self, name: str) -> bool:
        """
        Support ``name in instances`` as a shorthand for :meth:`has`.

        Without this method, ``in`` would fall back to ``__iter__``, which raises.
        """
        return self.has(name)

    def remove(self, name: str) -> None:
        """
        Remove the field called `name`.

        Raises:
            KeyError: if no such field exists.
        """
        del self._fields[name]

    def get(self, name: str) -> Any:
        """
        Returns the field called `name`.

        Raises:
            KeyError: if no such field exists.
        """
        return self._fields[name]

    def get_fields(self) -> Dict[str, Any]:
        """
        Returns:
            dict: a dict which maps names (str) to data of the fields

        Modifying the returned dict will modify this instance.
        """
        return self._fields

    # Tensor-like methods
    def to(self, *args: Any, **kwargs: Any) -> "Instances":
        """
        Returns:
            Instances: a new object in which every field that has a ``to`` method
            is replaced by ``field.to(*args, **kwargs)`` (e.g. a device move).
            Fields without a ``to`` method are carried over as the same object.
        """
        ret = Instances(self._image_size)
        for k, v in self._fields.items():
            if hasattr(v, "to"):
                v = v.to(*args, **kwargs)
            ret.set(k, v)
        return ret

    def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Instances":
        """
        Args:
            item: an index-like object and will be used to index all the fields.

        Returns:
            An `Instances` where all fields are indexed by `item`. An ``int`` index
            yields a length-1 `Instances` rather than a bare element.

        Raises:
            IndexError: if `item` is an ``int`` outside ``[-len(self), len(self))``.
        """
        # Exact type check (not isinstance) so that bools are passed through untouched.
        if type(item) is int:
            n = len(self)
            if item >= n or item < -n:
                raise IndexError("Instances index out of range!")
            # A slice starting at `item` with step == length selects exactly one
            # element while keeping the leading dimension, so each field remains
            # a sequence and the result is still a valid Instances.
            item = slice(item, None, n)

        ret = Instances(self._image_size)
        for k, v in self._fields.items():
            ret.set(k, v[item])
        return ret

    def __len__(self) -> int:
        # All fields share one length, so the first field's length is the answer.
        for v in self._fields.values():
            # use __len__ because len() has to be int and is not friendly to tracing
            return v.__len__()
        raise NotImplementedError("Empty Instances does not support __len__!")

    def __iter__(self):
        # Iteration is deliberately unsupported; index or use the fields instead.
        raise NotImplementedError("`Instances` object is not iterable!")

    @staticmethod
    def cat(instance_lists: List["Instances"]) -> "Instances":
        """
        Concatenate several `Instances` along the instance dimension.

        Args:
            instance_lists (list[Instances]): a non-empty list. All elements must
                have the same ``image_size`` and contain every field of the first.

        Returns:
            Instances: a new object whose fields are the concatenated fields of the
            inputs. When the list has a single element, that element itself is
            returned (not a copy).

        Raises:
            AssertionError: on an empty list, non-Instances elements, or mismatched
                image sizes.
            KeyError: if a later element lacks a field of the first.
            ValueError: if a field's type does not support concatenation.
        """
        assert all(isinstance(i, Instances) for i in instance_lists)
        assert len(instance_lists) > 0
        if len(instance_lists) == 1:
            return instance_lists[0]

        image_size = instance_lists[0].image_size
        if not isinstance(image_size, torch.Tensor):  # could be a tensor in tracing
            for i in instance_lists[1:]:
                assert i.image_size == image_size
        ret = Instances(image_size)
        for k in instance_lists[0]._fields:
            values = [i.get(k) for i in instance_lists]
            v0 = values[0]
            if isinstance(v0, torch.Tensor):
                merged = torch.cat(values, dim=0)
            elif isinstance(v0, list):
                merged = list(itertools.chain.from_iterable(values))
            elif hasattr(type(v0), "cat"):
                # Field types (e.g. box/mask containers) may provide their own `cat`.
                merged = type(v0).cat(values)
            else:
                raise ValueError("Unsupported type {} for concatenation".format(type(v0)))
            ret.set(k, merged)
        return ret

    def __str__(self) -> str:
        # An Instances without fields has no defined length (__len__ raises), but
        # str()/repr() must never raise, so report 0 instances in that case.
        num_instances = len(self) if self._fields else 0
        fields = ", ".join(f"{k}: {v}" for k, v in self._fields.items())
        return (
            f"{self.__class__.__name__}("
            f"num_instances={num_instances}, "
            f"image_height={self._image_size[0]}, "
            f"image_width={self._image_size[1]}, "
            f"fields=[{fields}])"
        )

    __repr__ = __str__