|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy |
|
|
from typing import Any, Iterator, Optional, Tuple, Type, Union |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
|
|
|
class BaseDataElement: |
|
|
"""A base data interface that supports Tensor-like and dict-like |
|
|
operations. |
|
|
|
|
|
A typical data elements refer to predicted results or ground truth labels |
|
|
on a task, such as predicted bboxes, instance masks, semantic |
|
|
segmentation masks, etc. Because groundtruth labels and predicted results |
|
|
often have similar properties (for example, the predicted bboxes and the |
|
|
groundtruth bboxes), MMEngine uses the same abstract data interface to |
|
|
encapsulate predicted results and groundtruth labels, and it is recommended |
|
|
to use different name conventions to distinguish them, such as using |
|
|
``gt_instances`` and ``pred_instances`` to distinguish between labels and |
|
|
predicted results. Additionally, we distinguish data elements at instance |
|
|
level, pixel level, and label level. Each of these types has its own |
|
|
characteristics. Therefore, MMEngine defines the base class |
|
|
``BaseDataElement``, and implement ``InstanceData``, ``PixelData``, and |
|
|
``LabelData`` inheriting from ``BaseDataElement`` to represent different |
|
|
types of ground truth labels or predictions. |
|
|
|
|
|
Another common data element is sample data. A sample data consists of input |
|
|
data (such as an image) and its annotations and predictions. In general, |
|
|
an image can have multiple types of annotations and/or predictions at the |
|
|
same time (for example, both pixel-level semantic segmentation annotations |
|
|
and instance-level detection bboxes annotations). All labels and |
|
|
predictions of a training sample are often passed between Dataset, Model, |
|
|
Visualizer, and Evaluator components. In order to simplify the interface |
|
|
between components, we can treat them as a large data element and |
|
|
encapsulate them. Such data elements are generally called XXDataSample in |
|
|
the OpenMMLab. Therefore, Similar to `nn.Module`, the `BaseDataElement` |
|
|
allows `BaseDataElement` as its attribute. Such a class generally |
|
|
encapsulates all the data of a sample in the algorithm library, and its |
|
|
attributes generally are various types of data elements. For example, |
|
|
MMDetection is assigned by the BaseDataElement to encapsulate all the data |
|
|
elements of the sample labeling and prediction of a sample in the |
|
|
algorithm library. |
|
|
|
|
|
The attributes in ``BaseDataElement`` are divided into two parts, |
|
|
the ``metainfo`` and the ``data`` respectively. |
|
|
|
|
|
- ``metainfo``: Usually contains the |
|
|
information about the image such as filename, |
|
|
image_shape, pad_shape, etc. The attributes can be accessed or |
|
|
modified by dict-like or object-like operations, such as |
|
|
``.`` (for data access and modification), ``in``, ``del``, |
|
|
``pop(str)``, ``get(str)``, ``metainfo_keys()``, |
|
|
``metainfo_values()``, ``metainfo_items()``, ``set_metainfo()`` (for |
|
|
set or change key-value pairs in metainfo). |
|
|
|
|
|
- ``data``: Annotations or model predictions are |
|
|
stored. The attributes can be accessed or modified by |
|
|
dict-like or object-like operations, such as |
|
|
``.``, ``in``, ``del``, ``pop(str)``, ``get(str)``, ``keys()``, |
|
|
``values()``, ``items()``. Users can also apply tensor-like |
|
|
methods to all :obj:`torch.Tensor` in the ``data_fields``, |
|
|
such as ``.cuda()``, ``.cpu()``, ``.numpy()``, ``.to()``, |
|
|
``to_tensor()``, ``.detach()``. |
|
|
|
|
|
Args: |
|
|
metainfo (dict, optional): A dict contains the meta information |
|
|
of single image, such as ``dict(img_shape=(512, 512, 3), |
|
|
scale_factor=(1, 1, 1, 1))``. Defaults to None. |
|
|
kwargs (dict, optional): A dict contains annotations of single image or |
|
|
model predictions. Defaults to None. |
|
|
|
|
|
Examples: |
|
|
>>> import torch |
|
|
>>> from mmengine.structures import BaseDataElement |
|
|
>>> gt_instances = BaseDataElement() |
|
|
>>> bboxes = torch.rand((5, 4)) |
|
|
>>> scores = torch.rand((5,)) |
|
|
>>> img_id = 0 |
|
|
>>> img_shape = (800, 1333) |
|
|
>>> gt_instances = BaseDataElement( |
|
|
... metainfo=dict(img_id=img_id, img_shape=img_shape), |
|
|
... bboxes=bboxes, scores=scores) |
|
|
>>> gt_instances = BaseDataElement( |
|
|
... metainfo=dict(img_id=img_id, img_shape=(640, 640))) |
|
|
|
|
|
>>> # new |
|
|
>>> gt_instances1 = gt_instances.new( |
|
|
... metainfo=dict(img_id=1, img_shape=(640, 640)), |
|
|
... bboxes=torch.rand((5, 4)), |
|
|
... scores=torch.rand((5,))) |
|
|
>>> gt_instances2 = gt_instances1.new() |
|
|
|
|
|
>>> # add and process property |
|
|
>>> gt_instances = BaseDataElement() |
|
|
>>> gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100))) |
|
|
>>> assert 'img_shape' in gt_instances.metainfo_keys() |
|
|
>>> assert 'img_shape' in gt_instances |
|
|
>>> assert 'img_shape' not in gt_instances.keys() |
|
|
>>> assert 'img_shape' in gt_instances.all_keys() |
|
|
>>> print(gt_instances.img_shape) |
|
|
(100, 100) |
|
|
>>> gt_instances.scores = torch.rand((5,)) |
|
|
>>> assert 'scores' in gt_instances.keys() |
|
|
>>> assert 'scores' in gt_instances |
|
|
>>> assert 'scores' in gt_instances.all_keys() |
|
|
>>> assert 'scores' not in gt_instances.metainfo_keys() |
|
|
>>> print(gt_instances.scores) |
|
|
tensor([0.5230, 0.7885, 0.2426, 0.3911, 0.4876]) |
|
|
>>> gt_instances.bboxes = torch.rand((5, 4)) |
|
|
>>> assert 'bboxes' in gt_instances.keys() |
|
|
>>> assert 'bboxes' in gt_instances |
|
|
>>> assert 'bboxes' in gt_instances.all_keys() |
|
|
>>> assert 'bboxes' not in gt_instances.metainfo_keys() |
|
|
>>> print(gt_instances.bboxes) |
|
|
tensor([[0.0900, 0.0424, 0.1755, 0.4469], |
|
|
[0.8648, 0.0592, 0.3484, 0.0913], |
|
|
[0.5808, 0.1909, 0.6165, 0.7088], |
|
|
[0.5490, 0.4209, 0.9416, 0.2374], |
|
|
[0.3652, 0.1218, 0.8805, 0.7523]]) |
|
|
|
|
|
>>> # delete and change property |
|
|
>>> gt_instances = BaseDataElement( |
|
|
... metainfo=dict(img_id=0, img_shape=(640, 640)), |
|
|
... bboxes=torch.rand((6, 4)), scores=torch.rand((6,))) |
|
|
>>> gt_instances.set_metainfo(dict(img_shape=(1280, 1280))) |
|
|
>>> gt_instances.img_shape # (1280, 1280) |
|
|
>>> gt_instances.bboxes = gt_instances.bboxes * 2 |
|
|
>>> gt_instances.get('img_shape', None) # (1280, 1280) |
|
|
>>> gt_instances.get('bboxes', None) # 6x4 tensor |
|
|
>>> del gt_instances.img_shape |
|
|
>>> del gt_instances.bboxes |
|
|
>>> assert 'img_shape' not in gt_instances |
|
|
>>> assert 'bboxes' not in gt_instances |
|
|
>>> gt_instances.pop('img_shape', None) # None |
|
|
>>> gt_instances.pop('bboxes', None) # None |
|
|
|
|
|
>>> # Tensor-like |
|
|
>>> cuda_instances = gt_instances.cuda() |
|
|
>>> cuda_instances = gt_instances.to('cuda:0') |
|
|
>>> cpu_instances = cuda_instances.cpu() |
|
|
>>> cpu_instances = cuda_instances.to('cpu') |
|
|
>>> fp16_instances = cuda_instances.to( |
|
|
... device=None, dtype=torch.float16, non_blocking=False, |
|
|
... copy=False, memory_format=torch.preserve_format) |
|
|
>>> cpu_instances = cuda_instances.detach() |
|
|
>>> np_instances = cpu_instances.numpy() |
|
|
|
|
|
>>> # print |
|
|
>>> metainfo = dict(img_shape=(800, 1196, 3)) |
|
|
>>> gt_instances = BaseDataElement( |
|
|
... metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3])) |
|
|
>>> sample = BaseDataElement(metainfo=metainfo, |
|
|
... gt_instances=gt_instances) |
|
|
>>> print(sample) |
|
|
<BaseDataElement( |
|
|
META INFORMATION |
|
|
img_shape: (800, 1196, 3) |
|
|
DATA FIELDS |
|
|
gt_instances: <BaseDataElement( |
|
|
META INFORMATION |
|
|
img_shape: (800, 1196, 3) |
|
|
DATA FIELDS |
|
|
det_labels: tensor([0, 1, 2, 3]) |
|
|
) at 0x7f0ec5eadc70> |
|
|
) at 0x7f0fea49e130> |
|
|
|
|
|
>>> # inheritance |
|
|
>>> class DetDataSample(BaseDataElement): |
|
|
... @property |
|
|
... def proposals(self): |
|
|
... return self._proposals |
|
|
... @proposals.setter |
|
|
... def proposals(self, value): |
|
|
... self.set_field(value, '_proposals', dtype=BaseDataElement) |
|
|
... @proposals.deleter |
|
|
... def proposals(self): |
|
|
... del self._proposals |
|
|
... @property |
|
|
... def gt_instances(self): |
|
|
... return self._gt_instances |
|
|
... @gt_instances.setter |
|
|
... def gt_instances(self, value): |
|
|
... self.set_field(value, '_gt_instances', |
|
|
... dtype=BaseDataElement) |
|
|
... @gt_instances.deleter |
|
|
... def gt_instances(self): |
|
|
... del self._gt_instances |
|
|
... @property |
|
|
... def pred_instances(self): |
|
|
... return self._pred_instances |
|
|
... @pred_instances.setter |
|
|
... def pred_instances(self, value): |
|
|
... self.set_field(value, '_pred_instances', |
|
|
... dtype=BaseDataElement) |
|
|
... @pred_instances.deleter |
|
|
... def pred_instances(self): |
|
|
... del self._pred_instances |
|
|
>>> det_sample = DetDataSample() |
|
|
>>> proposals = BaseDataElement(bboxes=torch.rand((5, 4))) |
|
|
>>> det_sample.proposals = proposals |
|
|
>>> assert 'proposals' in det_sample |
|
|
>>> assert det_sample.proposals == proposals |
|
|
>>> del det_sample.proposals |
|
|
>>> assert 'proposals' not in det_sample |
|
|
>>> with self.assertRaises(AssertionError): |
|
|
... det_sample.proposals = torch.rand((5, 4)) |
|
|
""" |
|
|
|
|
|
def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None: |
|
|
|
|
|
self._metainfo_fields: set = set() |
|
|
self._data_fields: set = set() |
|
|
|
|
|
if metainfo is not None: |
|
|
self.set_metainfo(metainfo=metainfo) |
|
|
if kwargs: |
|
|
self.set_data(kwargs) |
|
|
|
|
|
def set_metainfo(self, metainfo: dict) -> None: |
|
|
"""Set or change key-value pairs in ``metainfo_field`` by parameter |
|
|
``metainfo``. |
|
|
|
|
|
Args: |
|
|
metainfo (dict): A dict contains the meta information |
|
|
of image, such as ``img_shape``, ``scale_factor``, etc. |
|
|
""" |
|
|
assert isinstance( |
|
|
metainfo, |
|
|
dict), f'metainfo should be a ``dict`` but got {type(metainfo)}' |
|
|
meta = copy.deepcopy(metainfo) |
|
|
for k, v in meta.items(): |
|
|
self.set_field(name=k, value=v, field_type='metainfo', dtype=None) |
|
|
|
|
|
def set_data(self, data: dict) -> None: |
|
|
"""Set or change key-value pairs in ``data_field`` by parameter |
|
|
``data``. |
|
|
|
|
|
Args: |
|
|
data (dict): A dict contains annotations of image or |
|
|
model predictions. |
|
|
""" |
|
|
assert isinstance(data, |
|
|
dict), f'data should be a `dict` but got {data}' |
|
|
for k, v in data.items(): |
|
|
|
|
|
|
|
|
setattr(self, k, v) |
|
|
|
|
|
def update(self, instance: 'BaseDataElement') -> None: |
|
|
"""The update() method updates the BaseDataElement with the elements |
|
|
from another BaseDataElement object. |
|
|
|
|
|
Args: |
|
|
instance (BaseDataElement): Another BaseDataElement object for |
|
|
update the current object. |
|
|
""" |
|
|
assert isinstance( |
|
|
instance, BaseDataElement |
|
|
), f'instance should be a `BaseDataElement` but got {type(instance)}' |
|
|
self.set_metainfo(dict(instance.metainfo_items())) |
|
|
self.set_data(dict(instance.items())) |
|
|
|
|
|
def new(self, |
|
|
*, |
|
|
metainfo: Optional[dict] = None, |
|
|
**kwargs) -> 'BaseDataElement': |
|
|
"""Return a new data element with same type. If ``metainfo`` and |
|
|
``data`` are None, the new data element will have same metainfo and |
|
|
data. If metainfo or data is not None, the new result will overwrite it |
|
|
with the input value. |
|
|
|
|
|
Args: |
|
|
metainfo (dict, optional): A dict contains the meta information |
|
|
of image, such as ``img_shape``, ``scale_factor``, etc. |
|
|
Defaults to None. |
|
|
kwargs (dict): A dict contains annotations of image or |
|
|
model predictions. |
|
|
|
|
|
Returns: |
|
|
BaseDataElement: A new data element with same type. |
|
|
""" |
|
|
new_data = self.__class__() |
|
|
|
|
|
if metainfo is not None: |
|
|
new_data.set_metainfo(metainfo) |
|
|
else: |
|
|
new_data.set_metainfo(dict(self.metainfo_items())) |
|
|
if kwargs: |
|
|
new_data.set_data(kwargs) |
|
|
else: |
|
|
new_data.set_data(dict(self.items())) |
|
|
return new_data |
|
|
|
|
|
def clone(self): |
|
|
"""Deep copy the current data element. |
|
|
|
|
|
Returns: |
|
|
BaseDataElement: The copy of current data element. |
|
|
""" |
|
|
clone_data = self.__class__() |
|
|
clone_data.set_metainfo(dict(self.metainfo_items())) |
|
|
clone_data.set_data(dict(self.items())) |
|
|
return clone_data |
|
|
|
|
|
def keys(self) -> list: |
|
|
""" |
|
|
Returns: |
|
|
list: Contains all keys in data_fields. |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private_keys = { |
|
|
'_' + key |
|
|
for key in self._data_fields |
|
|
if isinstance(getattr(type(self), key, None), property) |
|
|
} |
|
|
return list(self._data_fields - private_keys) |
|
|
|
|
|
def metainfo_keys(self) -> list: |
|
|
""" |
|
|
Returns: |
|
|
list: Contains all keys in metainfo_fields. |
|
|
""" |
|
|
return list(self._metainfo_fields) |
|
|
|
|
|
def values(self) -> list: |
|
|
""" |
|
|
Returns: |
|
|
list: Contains all values in data. |
|
|
""" |
|
|
return [getattr(self, k) for k in self.keys()] |
|
|
|
|
|
def metainfo_values(self) -> list: |
|
|
""" |
|
|
Returns: |
|
|
list: Contains all values in metainfo. |
|
|
""" |
|
|
return [getattr(self, k) for k in self.metainfo_keys()] |
|
|
|
|
|
def all_keys(self) -> list: |
|
|
""" |
|
|
Returns: |
|
|
list: Contains all keys in metainfo and data. |
|
|
""" |
|
|
return self.metainfo_keys() + self.keys() |
|
|
|
|
|
def all_values(self) -> list: |
|
|
""" |
|
|
Returns: |
|
|
list: Contains all values in metainfo and data. |
|
|
""" |
|
|
return self.metainfo_values() + self.values() |
|
|
|
|
|
def all_items(self) -> Iterator[Tuple[str, Any]]: |
|
|
""" |
|
|
Returns: |
|
|
iterator: An iterator object whose element is (key, value) tuple |
|
|
pairs for ``metainfo`` and ``data``. |
|
|
""" |
|
|
for k in self.all_keys(): |
|
|
yield (k, getattr(self, k)) |
|
|
|
|
|
def items(self) -> Iterator[Tuple[str, Any]]: |
|
|
""" |
|
|
Returns: |
|
|
iterator: An iterator object whose element is (key, value) tuple |
|
|
pairs for ``data``. |
|
|
""" |
|
|
for k in self.keys(): |
|
|
yield (k, getattr(self, k)) |
|
|
|
|
|
def metainfo_items(self) -> Iterator[Tuple[str, Any]]: |
|
|
""" |
|
|
Returns: |
|
|
iterator: An iterator object whose element is (key, value) tuple |
|
|
pairs for ``metainfo``. |
|
|
""" |
|
|
for k in self.metainfo_keys(): |
|
|
yield (k, getattr(self, k)) |
|
|
|
|
|
@property |
|
|
def metainfo(self) -> dict: |
|
|
"""dict: A dict contains metainfo of current data element.""" |
|
|
return dict(self.metainfo_items()) |
|
|
|
|
|
def __setattr__(self, name: str, value: Any): |
|
|
"""setattr is only used to set data.""" |
|
|
if name in ('_metainfo_fields', '_data_fields'): |
|
|
if not hasattr(self, name): |
|
|
super().__setattr__(name, value) |
|
|
else: |
|
|
raise AttributeError(f'{name} has been used as a ' |
|
|
'private attribute, which is immutable.') |
|
|
else: |
|
|
self.set_field( |
|
|
name=name, value=value, field_type='data', dtype=None) |
|
|
|
|
|
def __delattr__(self, item: str): |
|
|
"""Delete the item in dataelement. |
|
|
|
|
|
Args: |
|
|
item (str): The key to delete. |
|
|
""" |
|
|
if item in ('_metainfo_fields', '_data_fields'): |
|
|
raise AttributeError(f'{item} has been used as a ' |
|
|
'private attribute, which is immutable.') |
|
|
super().__delattr__(item) |
|
|
if item in self._metainfo_fields: |
|
|
self._metainfo_fields.remove(item) |
|
|
elif item in self._data_fields: |
|
|
self._data_fields.remove(item) |
|
|
|
|
|
|
|
|
__delitem__ = __delattr__ |
|
|
|
|
|
def get(self, key, default=None) -> Any: |
|
|
"""Get property in data and metainfo as the same as python.""" |
|
|
|
|
|
|
|
|
return getattr(self, key, default) |
|
|
|
|
|
def pop(self, *args) -> Any: |
|
|
"""Pop property in data and metainfo as the same as python.""" |
|
|
assert len(args) < 3, '``pop`` get more than 2 arguments' |
|
|
name = args[0] |
|
|
if name in self._metainfo_fields: |
|
|
self._metainfo_fields.remove(args[0]) |
|
|
return self.__dict__.pop(*args) |
|
|
|
|
|
elif name in self._data_fields: |
|
|
self._data_fields.remove(args[0]) |
|
|
return self.__dict__.pop(*args) |
|
|
|
|
|
|
|
|
elif len(args) == 2: |
|
|
return args[1] |
|
|
else: |
|
|
|
|
|
|
|
|
raise KeyError(f'{args[0]} is not contained in metainfo or data') |
|
|
|
|
|
def __contains__(self, item: str) -> bool: |
|
|
"""Whether the item is in dataelement. |
|
|
|
|
|
Args: |
|
|
item (str): The key to inquire. |
|
|
""" |
|
|
return item in self._data_fields or item in self._metainfo_fields |
|
|
|
|
|
def set_field(self, |
|
|
value: Any, |
|
|
name: str, |
|
|
dtype: Optional[Union[Type, Tuple[Type, ...]]] = None, |
|
|
field_type: str = 'data') -> None: |
|
|
"""Special method for set union field, used as property.setter |
|
|
functions.""" |
|
|
assert field_type in ['metainfo', 'data'] |
|
|
if dtype is not None: |
|
|
assert isinstance( |
|
|
value, |
|
|
dtype), f'{value} should be a {dtype} but got {type(value)}' |
|
|
|
|
|
if field_type == 'metainfo': |
|
|
if name in self._data_fields: |
|
|
raise AttributeError( |
|
|
f'Cannot set {name} to be a field of metainfo ' |
|
|
f'because {name} is already a data field') |
|
|
self._metainfo_fields.add(name) |
|
|
else: |
|
|
if name in self._metainfo_fields: |
|
|
raise AttributeError( |
|
|
f'Cannot set {name} to be a field of data ' |
|
|
f'because {name} is already a metainfo field') |
|
|
self._data_fields.add(name) |
|
|
super().__setattr__(name, value) |
|
|
|
|
|
|
|
|
def to(self, *args, **kwargs) -> 'BaseDataElement': |
|
|
"""Apply same name function to all tensors in data_fields.""" |
|
|
new_data = self.new() |
|
|
for k, v in self.items(): |
|
|
if hasattr(v, 'to'): |
|
|
v = v.to(*args, **kwargs) |
|
|
data = {k: v} |
|
|
new_data.set_data(data) |
|
|
return new_data |
|
|
|
|
|
|
|
|
def cpu(self) -> 'BaseDataElement': |
|
|
"""Convert all tensors to CPU in data.""" |
|
|
new_data = self.new() |
|
|
for k, v in self.items(): |
|
|
if isinstance(v, (torch.Tensor, BaseDataElement)): |
|
|
v = v.cpu() |
|
|
data = {k: v} |
|
|
new_data.set_data(data) |
|
|
return new_data |
|
|
|
|
|
|
|
|
def cuda(self) -> 'BaseDataElement': |
|
|
"""Convert all tensors to GPU in data.""" |
|
|
new_data = self.new() |
|
|
for k, v in self.items(): |
|
|
if isinstance(v, (torch.Tensor, BaseDataElement)): |
|
|
v = v.cuda() |
|
|
data = {k: v} |
|
|
new_data.set_data(data) |
|
|
return new_data |
|
|
|
|
|
|
|
|
def npu(self) -> 'BaseDataElement': |
|
|
"""Convert all tensors to NPU in data.""" |
|
|
new_data = self.new() |
|
|
for k, v in self.items(): |
|
|
if isinstance(v, (torch.Tensor, BaseDataElement)): |
|
|
v = v.npu() |
|
|
data = {k: v} |
|
|
new_data.set_data(data) |
|
|
return new_data |
|
|
|
|
|
def mlu(self) -> 'BaseDataElement': |
|
|
"""Convert all tensors to MLU in data.""" |
|
|
new_data = self.new() |
|
|
for k, v in self.items(): |
|
|
if isinstance(v, (torch.Tensor, BaseDataElement)): |
|
|
v = v.mlu() |
|
|
data = {k: v} |
|
|
new_data.set_data(data) |
|
|
return new_data |
|
|
|
|
|
|
|
|
def detach(self) -> 'BaseDataElement': |
|
|
"""Detach all tensors in data.""" |
|
|
new_data = self.new() |
|
|
for k, v in self.items(): |
|
|
if isinstance(v, (torch.Tensor, BaseDataElement)): |
|
|
v = v.detach() |
|
|
data = {k: v} |
|
|
new_data.set_data(data) |
|
|
return new_data |
|
|
|
|
|
|
|
|
def numpy(self) -> 'BaseDataElement': |
|
|
"""Convert all tensors to np.ndarray in data.""" |
|
|
new_data = self.new() |
|
|
for k, v in self.items(): |
|
|
if isinstance(v, (torch.Tensor, BaseDataElement)): |
|
|
v = v.detach().cpu().numpy() |
|
|
data = {k: v} |
|
|
new_data.set_data(data) |
|
|
return new_data |
|
|
|
|
|
def to_tensor(self) -> 'BaseDataElement': |
|
|
"""Convert all np.ndarray to tensor in data.""" |
|
|
new_data = self.new() |
|
|
for k, v in self.items(): |
|
|
data = {} |
|
|
if isinstance(v, np.ndarray): |
|
|
v = torch.from_numpy(v) |
|
|
data[k] = v |
|
|
elif isinstance(v, BaseDataElement): |
|
|
v = v.to_tensor() |
|
|
data[k] = v |
|
|
new_data.set_data(data) |
|
|
return new_data |
|
|
|
|
|
def to_dict(self) -> dict: |
|
|
"""Convert BaseDataElement to dict.""" |
|
|
return { |
|
|
k: v.to_dict() if isinstance(v, BaseDataElement) else v |
|
|
for k, v in self.all_items() |
|
|
} |
|
|
|
|
|
def __repr__(self) -> str: |
|
|
"""Represent the object.""" |
|
|
|
|
|
def _addindent(s_: str, num_spaces: int) -> str: |
|
|
"""This func is modified from `pytorch` https://github.com/pytorch/ |
|
|
pytorch/blob/b17b2b1cc7b017c3daaeff8cc7ec0f514d42ec37/torch/nn/modu |
|
|
les/module.py#L29. |
|
|
|
|
|
Args: |
|
|
s_ (str): The string to add spaces. |
|
|
num_spaces (int): The num of space to add. |
|
|
|
|
|
Returns: |
|
|
str: The string after add indent. |
|
|
""" |
|
|
s = s_.split('\n') |
|
|
|
|
|
if len(s) == 1: |
|
|
return s_ |
|
|
first = s.pop(0) |
|
|
s = [(num_spaces * ' ') + line for line in s] |
|
|
s = '\n'.join(s) |
|
|
s = first + '\n' + s |
|
|
return s |
|
|
|
|
|
def dump(obj: Any) -> str: |
|
|
"""Represent the object. |
|
|
|
|
|
Args: |
|
|
obj (Any): The obj to represent. |
|
|
|
|
|
Returns: |
|
|
str: The represented str. |
|
|
""" |
|
|
_repr = '' |
|
|
if isinstance(obj, dict): |
|
|
for k, v in obj.items(): |
|
|
_repr += f'\n{k}: {_addindent(dump(v), 4)}' |
|
|
elif isinstance(obj, BaseDataElement): |
|
|
_repr += '\n\n META INFORMATION' |
|
|
metainfo_items = dict(obj.metainfo_items()) |
|
|
_repr += _addindent(dump(metainfo_items), 4) |
|
|
_repr += '\n\n DATA FIELDS' |
|
|
items = dict(obj.items()) |
|
|
_repr += _addindent(dump(items), 4) |
|
|
classname = obj.__class__.__name__ |
|
|
_repr = f'<{classname}({_repr}\n) at {hex(id(obj))}>' |
|
|
else: |
|
|
_repr += repr(obj) |
|
|
return _repr |
|
|
|
|
|
return dump(self) |
|
|
|