| | |
| | import warnings |
| | from typing import List, Sequence, Union |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| | from .base_data_element import BaseDataElement |
| |
|
| |
|
class PixelData(BaseDataElement):
    """Data structure for pixel-level annotations or predictions.

    All data items in ``data_fields`` of ``PixelData`` meet the following
    requirements:

    - They all have 3 dimensions in orders of channel, height, and width.
    - They should have the same height and width.

    A 2-D value assigned to a ``PixelData`` is expanded to 3-D by adding a
    leading channel dimension of size 1 (a warning is emitted).

    Examples:
        >>> import random
        >>> import numpy as np
        >>> import torch
        >>> metainfo = dict(
        ...     img_id=random.randint(0, 100),
        ...     img_shape=(random.randint(400, 600), random.randint(400, 600)))
        >>> image = np.random.randint(0, 255, (4, 20, 40))
        >>> featmap = torch.randint(0, 255, (10, 20, 40))
        >>> pixel_data = PixelData(metainfo=metainfo,
        ...                        image=image,
        ...                        featmap=featmap)
        >>> print(pixel_data.shape)
        (20, 40)

        >>> # slice
        >>> slice_data = pixel_data[10:20, 20:40]
        >>> assert slice_data.shape == (10, 20)
        >>> slice_data = pixel_data[10, 20]
        >>> assert slice_data.shape == (1, 1)

        >>> # set
        >>> pixel_data.map3 = torch.randint(0, 255, (20, 40))
        >>> assert tuple(pixel_data.map3.shape) == (1, 20, 40)
        >>> try:
        ...     # The dimension must be 3 or 2
        ...     pixel_data.map2 = torch.randint(0, 255, (1, 3, 20, 40))
        ... except AssertionError:
        ...     print('rejected')
        rejected
    """

    def __setattr__(self, name: str, value: Union[torch.Tensor, np.ndarray]):
        """Set attributes of ``PixelData``.

        The two private bookkeeping attributes ``_metainfo_fields`` and
        ``_data_fields`` can be set only once. Every other attribute must be
        a 2-D or 3-D ``torch.Tensor`` / ``np.ndarray``. When data is already
        stored, its last two dimensions must equal :attr:`shape`. A 2-D value
        is expanded to ``(1, H, W)`` and a warning is emitted.

        Args:
            name (str): The key to access the value, stored in `PixelData`.
            value (Union[torch.Tensor, np.ndarray]): The value to store in.
                The type of value must be `torch.Tensor` or `np.ndarray`,
                and its shape must meet the requirements of `PixelData`.

        Raises:
            AttributeError: If ``name`` is ``_metainfo_fields`` or
                ``_data_fields`` and that attribute already exists.
            AssertionError: If ``value`` has an unsupported type or number
                of dimensions, or its height/width disagree with the data
                already stored.
        """
        if name in ('_metainfo_fields', '_data_fields'):
            if not hasattr(self, name):
                super().__setattr__(name, value)
            else:
                raise AttributeError(f'{name} has been used as a '
                                     'private attribute, which is immutable.')
            return

        assert isinstance(value, (torch.Tensor, np.ndarray)), \
            f'Can not set {type(value)}, only support' \
            f' {(torch.Tensor, np.ndarray)}'
        # Validate the rank before the spatial size, so that e.g. a 1-D value
        # is reported as a rank error instead of a confusing shape mismatch.
        assert value.ndim in [
            2, 3
        ], f'The dim of value must be 2 or 3, but got {value.ndim}'
        if self.shape:
            assert tuple(value.shape[-2:]) == self.shape, (
                'The height and width of '
                f'values {tuple(value.shape[-2:])} is '
                'not consistent with '
                'the shape of this '
                ':obj:`PixelData` '
                f'{self.shape}')
        if value.ndim == 2:
            # Add the leading channel dimension: (H, W) -> (1, H, W).
            original_shape = tuple(value.shape)
            value = value[None]
            # tuple() keeps the message free of ``torch.Size([...])`` reprs.
            warnings.warn('The shape of value will convert from '
                          f'{original_shape} to {tuple(value.shape)}')
        super().__setattr__(name, value)

    def __getitem__(self, item: Sequence[Union[int, slice]]) -> 'PixelData':
        """Slice every data field along its height and width dimensions.

        Args:
            item (Sequence[Union[int, slice]]): A ``(height, width)`` tuple
                whose elements are ints (including NumPy integer scalars) or
                slices. An int selects a single row/column but keeps that
                dimension with size 1, so every field stays 3-D.

        Returns:
            :obj:`PixelData`: A new object with the same metainfo whose data
            fields are the corresponding slices of this object's fields.

        Raises:
            TypeError: If ``item`` is not a tuple, or one of its elements is
                neither an int nor a slice.
            IndexError: If an int index is out of range for its dimension.
            AssertionError: If ``item`` does not have exactly two elements.
        """
        if not isinstance(item, tuple):
            raise TypeError(
                f'Unsupported type {type(item)} for slicing PixelData')
        assert len(item) == 2, 'Only support to slice height and width'

        # ``shape`` is None while no data field is stored; the sizes are then
        # unknown, but there is also nothing to slice.
        sizes = self.shape or (None, None)
        # Keep the channel dimension whole; slice only height and width.
        index = (slice(None, None, None), ) + tuple(
            self._index_to_slice(single_item, size)
            for single_item, size in zip(item, sizes))

        new_data = self.__class__(metainfo=self.metainfo)
        for k, v in self.items():
            setattr(new_data, k, v[index])
        return new_data

    @staticmethod
    def _index_to_slice(single_item: Union[int, slice],
                        size: Union[int, None]) -> slice:
        """Convert one height/width index into a dimension-preserving slice.

        Args:
            single_item (Union[int, slice]): The index for one dimension.
            size (int, optional): Length of that dimension, or ``None`` when
                unknown (no data stored), in which case bounds are not checked.

        Returns:
            slice: ``single_item`` itself if it is a slice, otherwise a slice
            selecting exactly that one position.
        """
        if isinstance(single_item, slice):
            return single_item
        if not isinstance(single_item, (int, np.integer)):
            raise TypeError(
                'The type of element in input must be int or slice, '
                f'but got {type(single_item)}')
        position = int(single_item)
        if size is not None:
            if not -size <= position < size:
                raise IndexError(
                    f'index {position} is out of bounds for a dimension '
                    f'of size {size}')
            if position < 0:
                position += size
        # ``or None`` keeps e.g. -1 -> slice(-1, None) when size is unknown,
        # instead of the empty slice(-1, 0).
        return slice(position, position + 1 or None)

    @property
    def shape(self):
        """tuple or None: The ``(height, width)`` shared by all data fields.

        Read from one stored data field, since all fields share the same
        height and width. ``None`` if no data field has been set yet.
        """
        if len(self._data_fields) > 0:
            return tuple(self.values()[0].shape[-2:])
        else:
            return None
| |
|
| | |
| |
|