Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import copy | |
import numpy as np | |
import torch | |
from mmdet.utils.util_mixins import NiceRepr | |
class GeneralData(NiceRepr):
    """A general data structure of OpenMMlab.

    A data structure that stores the meta information,
    the annotations of the images or the model predictions,
    which can be used in communication between components.

    The attributes in `GeneralData` are divided into two parts,
    the `meta_info_fields` and the `data_fields` respectively.

        - `meta_info_fields`: Usually contains the
          information about the image such as filename,
          image_shape, pad_shape, etc. All attributes in
          it are immutable once set,
          but the user can add new meta information with
          `set_meta_info` function, all information can be accessed
          with methods `meta_info_keys`, `meta_info_values`,
          `meta_info_items`.

        - `data_fields`: Annotations or model predictions are
          stored. The attributes can be accessed or modified by
          dict-like or object-like operations, such as
          `.` , `[]`, `in`, `del`, `pop(str)` `get(str)`, `keys()`,
          `values()`, `items()`. Users can also apply tensor-like methods
          to all obj:`torch.Tensor` in the `data_fields`,
          such as `.to()`, `.cpu()`, `.cuda()`, `.npu()`, `.mlu()`,
          `.detach()`, `.numpy()`

    Args:
        meta_info (dict, optional): A dict contains the meta information
            of single image. such as `img_shape`, `scale_factor`, etc.
            Default: None.
        data (dict, optional): A dict contains annotations of single image or
            model predictions. Default: None.

    Examples:
        >>> from mmdet.core import GeneralData
        >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
        >>> instance_data = GeneralData(meta_info=img_meta)
        >>> 'img_shape' in instance_data
        True
        >>> instance_data.det_labels = torch.LongTensor([0, 1, 2, 3])
        >>> instance_data["det_scores"] = torch.Tensor([0.01, 0.1, 0.2, 0.3])
        >>> print(instance_data)
        <GeneralData(

          META INFORMATION
        img_shape: (800, 1196, 3)
        pad_shape: (800, 1216, 3)

          DATA FIELDS
        shape of det_labels: torch.Size([4])
        shape of det_scores: torch.Size([4])

        ) at 0x7f84acd10f90>
        >>> instance_data.det_scores
        tensor([0.0100, 0.1000, 0.2000, 0.3000])
        >>> instance_data.det_labels
        tensor([0, 1, 2, 3])
        >>> instance_data['det_labels']
        tensor([0, 1, 2, 3])
        >>> 'det_labels' in instance_data
        True
        >>> instance_data.img_shape
        (800, 1196, 3)
        >>> 'det_scores' in instance_data
        True
        >>> del instance_data.det_scores
        >>> 'det_scores' in instance_data
        False
        >>> det_labels = instance_data.pop('det_labels', None)
        >>> det_labels
        tensor([0, 1, 2, 3])
        >>> 'det_labels' in instance_data
        False
    """

    def __init__(self, meta_info=None, data=None):
        # The two bookkeeping sets must exist before any other attribute is
        # assigned, because `__setattr__` consults them.
        self._meta_info_fields = set()
        self._data_fields = set()

        if meta_info is not None:
            self.set_meta_info(meta_info=meta_info)
        if data is not None:
            self.set_data(data)

    @staticmethod
    def _is_same_meta_value(old, new):
        """Return True if `new` is consistent with the stored value `old`.

        Tensor / ndarray comparisons are reduced with `.all()`. A comparison
        that cannot be evaluated (e.g. mismatching shapes, or a nested
        container whose truth value is ambiguous) is treated as a mismatch
        so that the caller can raise its own immutability error.
        """
        try:
            same = old == new
            if isinstance(same, (torch.Tensor, np.ndarray)):
                same = same.all()
            return bool(same)
        except (TypeError, ValueError, RuntimeError):
            return False

    def set_meta_info(self, meta_info):
        """Add meta information.

        The input dict is deep-copied before being stored. A key that is
        already part of the meta information may only be set again with a
        value equal to the stored one.

        Args:
            meta_info (dict): A dict contains the meta information
                of image. such as `img_shape`, `scale_factor`, etc.

        Raises:
            KeyError: If a key already exists in the meta information with
                a different value.
        """
        assert isinstance(meta_info,
                          dict), f'meta should be a `dict` but get {meta_info}'
        meta = copy.deepcopy(meta_info)
        for k, v in meta.items():
            # should be consistent with original meta_info
            if k in self._meta_info_fields:
                if self._is_same_meta_value(getattr(self, k), v):
                    continue
                raise KeyError(
                    f'img_meta_info {k} has been set as '
                    f'{getattr(self, k)} before, which is immutable ')
            self._meta_info_fields.add(k)
            # Written through `__dict__` because `__setattr__` would
            # register the key as a data field.
            self.__dict__[k] = v

    def set_data(self, data):
        """Update a dict to `data_fields`.

        Args:
            data (dict): A dict contains annotations of image or
                model predictions.
        """
        assert isinstance(data,
                          dict), f'data should be a `dict` but get {data}'
        for k, v in data.items():
            setattr(self, k, v)

    def new(self, meta_info=None, data=None):
        """Return a new results with same image meta information.

        Args:
            meta_info (dict, optional): A dict contains the meta information
                of image. such as `img_shape`, `scale_factor`, etc.
                Default: None.
            data (dict, optional): A dict contains annotations of image or
                model predictions. Default: None.

        Returns:
            An instance of the same class holding the current meta
            information, extended with `meta_info` and `data` if given.
        """
        new_data = self.__class__()
        new_data.set_meta_info(dict(self.meta_info_items()))
        if meta_info is not None:
            new_data.set_meta_info(meta_info)
        if data is not None:
            new_data.set_data(data)
        return new_data

    def keys(self):
        """
        Returns:
            list: Contains all keys in data_fields.
        """
        return list(self._data_fields)

    def meta_info_keys(self):
        """
        Returns:
            list: Contains all keys in meta_info_fields.
        """
        return list(self._meta_info_fields)

    def values(self):
        """
        Returns:
            list: Contains all values in data_fields.
        """
        return [getattr(self, k) for k in self.keys()]

    def meta_info_values(self):
        """
        Returns:
            list: Contains all values in meta_info_fields.
        """
        return [getattr(self, k) for k in self.meta_info_keys()]

    def items(self):
        """Yield `(key, value)` pairs of the data_fields."""
        for k in self.keys():
            yield (k, getattr(self, k))

    def meta_info_items(self):
        """Yield `(key, value)` pairs of the meta_info_fields."""
        for k in self.meta_info_keys():
            yield (k, getattr(self, k))

    def __setattr__(self, name, val):
        """Set a data field.

        The two private bookkeeping sets can only be assigned once, and
        names used by the meta information cannot be overwritten.

        Raises:
            AttributeError: If `name` is an already initialised private
                attribute or a key of the meta information.
        """
        if name in ('_meta_info_fields', '_data_fields'):
            if not hasattr(self, name):
                super().__setattr__(name, val)
            else:
                raise AttributeError(
                    f'{name} has been used as a '
                    f'private attribute, which is immutable. ')
        else:
            if name in self._meta_info_fields:
                raise AttributeError(f'`{name}` is used in meta information,'
                                     f'which is immutable')

            self._data_fields.add(name)
            super().__setattr__(name, val)

    def __delattr__(self, item):
        """Delete a data field.

        Raises:
            AttributeError: If `item` is one of the private bookkeeping
                attributes.
            KeyError: If `item` is a key of the meta information.
        """
        if item in ('_meta_info_fields', '_data_fields'):
            raise AttributeError(f'{item} has been used as a '
                                 f'private attribute, which is immutable. ')

        if item in self._meta_info_fields:
            raise KeyError(f'{item} is used in meta information, '
                           f'which is immutable.')
        # Delete first so that a missing attribute raises before the
        # bookkeeping set is touched.
        super().__delattr__(item)
        self._data_fields.discard(item)

    # dict-like methods
    __setitem__ = __setattr__
    __delitem__ = __delattr__

    def __getitem__(self, name):
        """Return the attribute called `name`."""
        return getattr(self, name)

    def get(self, *args):
        """Look up `args[0]` in the instance `__dict__`, like `dict.get`.

        An optional second positional argument is returned when the key
        is missing.
        """
        assert len(args) < 3, '`get` get more than 2 arguments'
        return self.__dict__.get(*args)

    def pop(self, *args):
        """Remove a data field and return its value, like `dict.pop`.

        An optional second positional argument is returned when the key
        is not a data field.

        Raises:
            KeyError: If the key belongs to the meta information, or if it
                is not a data field and no default is given.
        """
        assert len(args) < 3, '`pop` get more than 2 arguments'
        name = args[0]
        if name in self._meta_info_fields:
            raise KeyError(f'{name} is a key in meta information, '
                           f'which is immutable')

        if name in self._data_fields:
            self._data_fields.remove(name)
            return self.__dict__.pop(*args)

        # with default value
        if len(args) == 2:
            return args[1]
        raise KeyError(f'{args[0]}')

    def __contains__(self, item):
        """Return True if `item` is a data field or a meta information key."""
        return item in self._data_fields or \
            item in self._meta_info_fields

    def _apply_to_tensors(self, func):
        """Return a new instance with `func` applied to every tensor field.

        Data fields that are not `torch.Tensor` are carried over unchanged.
        """
        new_data = self.new()
        for k, v in self.items():
            if isinstance(v, torch.Tensor):
                v = func(v)
            new_data[k] = v
        return new_data

    # Tensor-like methods
    def to(self, *args, **kwargs):
        """Apply same name function to all tensors in data_fields."""
        new_data = self.new()
        for k, v in self.items():
            # Unlike the other tensor-like methods, any value exposing a
            # `to` attribute is converted, not only `torch.Tensor`.
            if hasattr(v, 'to'):
                v = v.to(*args, **kwargs)
            new_data[k] = v
        return new_data

    # Tensor-like methods
    def cpu(self):
        """Apply same name function to all tensors in data_fields."""
        return self._apply_to_tensors(lambda t: t.cpu())

    # Tensor-like methods
    def npu(self):
        """Apply same name function to all tensors in data_fields."""
        return self._apply_to_tensors(lambda t: t.npu())

    # Tensor-like methods
    def mlu(self):
        """Apply same name function to all tensors in data_fields."""
        return self._apply_to_tensors(lambda t: t.mlu())

    # Tensor-like methods
    def cuda(self):
        """Apply same name function to all tensors in data_fields."""
        return self._apply_to_tensors(lambda t: t.cuda())

    # Tensor-like methods
    def detach(self):
        """Apply same name function to all tensors in data_fields."""
        return self._apply_to_tensors(lambda t: t.detach())

    # Tensor-like methods
    def numpy(self):
        """Convert all tensors in data_fields to `np.ndarray`.

        Each tensor is detached and moved to CPU before the conversion.
        """
        return self._apply_to_tensors(lambda t: t.detach().cpu().numpy())

    def __nice__(self):
        """Build the text listing meta information and data fields.

        Tensors and arrays are summarised by their shape; every other
        value is formatted as is.
        """
        parts = ['\n \n  META INFORMATION \n']
        parts.extend(f'{k}: {v} \n' for k, v in self.meta_info_items())
        parts.append('\n   DATA FIELDS \n')
        for k, v in self.items():
            if isinstance(v, (torch.Tensor, np.ndarray)):
                parts.append(f'shape of {k}: {v.shape} \n')
            else:
                parts.append(f'{k}: {v} \n')
        parts.append('\n')
        return ''.join(parts)