Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from unittest import TestCase | |
| import numpy as np | |
| import torch | |
| from mmengine.structures import InstanceData | |
| from mmocr.structures import KIEDataSample | |
| class TestTextDetDataSample(TestCase): | |
| def _equal(self, a, b): | |
| if isinstance(a, (torch.Tensor, np.ndarray)): | |
| return (a == b).all() | |
| else: | |
| return a == b | |
| def test_init(self): | |
| meta_info = dict( | |
| img_size=[256, 256], | |
| scale_factor=np.array([1.5, 1.5]), | |
| img_shape=torch.rand(4)) | |
| kie_data_sample = KIEDataSample(metainfo=meta_info) | |
| assert 'img_size' in kie_data_sample | |
| self.assertListEqual(kie_data_sample.img_size, [256, 256]) | |
| self.assertListEqual(kie_data_sample.get('img_size'), [256, 256]) | |
| def test_setter(self): | |
| kie_data_sample = KIEDataSample() | |
| # test gt_instances | |
| gt_instances_data = dict( | |
| bboxes=torch.rand(4, 4), | |
| labels=torch.rand(4), | |
| texts=['t1', 't2', 't3', 't4'], | |
| relations=torch.rand(4, 4), | |
| edge_labels=torch.randint(0, 4, (4, ))) | |
| gt_instances = InstanceData(**gt_instances_data) | |
| kie_data_sample.gt_instances = gt_instances | |
| self.assertIn('gt_instances', kie_data_sample) | |
| self.assertTrue( | |
| self._equal(kie_data_sample.gt_instances.bboxes, | |
| gt_instances_data['bboxes'])) | |
| self.assertTrue( | |
| self._equal(kie_data_sample.gt_instances.labels, | |
| gt_instances_data['labels'])) | |
| self.assertTrue( | |
| self._equal(kie_data_sample.gt_instances.texts, | |
| gt_instances_data['texts'])) | |
| self.assertTrue( | |
| self._equal(kie_data_sample.gt_instances.relations, | |
| gt_instances_data['relations'])) | |
| self.assertTrue( | |
| self._equal(kie_data_sample.gt_instances.edge_labels, | |
| gt_instances_data['edge_labels'])) | |
| # test pred_instances | |
| pred_instances_data = dict( | |
| bboxes=torch.rand(4, 4), | |
| labels=torch.rand(4), | |
| texts=['t1', 't2', 't3', 't4'], | |
| relations=torch.rand(4, 4), | |
| edge_labels=torch.randint(0, 4, (4, ))) | |
| pred_instances = InstanceData(**pred_instances_data) | |
| kie_data_sample.pred_instances = pred_instances | |
| assert 'pred_instances' in kie_data_sample | |
| assert self._equal(kie_data_sample.pred_instances.bboxes, | |
| pred_instances_data['bboxes']) | |
| assert self._equal(kie_data_sample.pred_instances.labels, | |
| pred_instances_data['labels']) | |
| self.assertTrue( | |
| self._equal(kie_data_sample.pred_instances.texts, | |
| pred_instances_data['texts'])) | |
| self.assertTrue( | |
| self._equal(kie_data_sample.pred_instances.relations, | |
| pred_instances_data['relations'])) | |
| self.assertTrue( | |
| self._equal(kie_data_sample.pred_instances.edge_labels, | |
| pred_instances_data['edge_labels'])) | |
| # test type error | |
| with self.assertRaises(AssertionError): | |
| kie_data_sample.gt_instances = torch.rand(2, 4) | |
| with self.assertRaises(AssertionError): | |
| kie_data_sample.pred_instances = torch.rand(2, 4) | |
| def test_deleter(self): | |
| gt_instances_data = dict( | |
| bboxes=torch.rand(4, 4), | |
| labels=torch.rand(4), | |
| ) | |
| kie_data_sample = KIEDataSample() | |
| gt_instances = InstanceData(data=gt_instances_data) | |
| kie_data_sample.gt_instances = gt_instances | |
| assert 'gt_instances' in kie_data_sample | |
| del kie_data_sample.gt_instances | |
| assert 'gt_instances' not in kie_data_sample | |
| kie_data_sample.pred_instances = gt_instances | |
| assert 'pred_instances' in kie_data_sample | |
| del kie_data_sample.pred_instances | |
| assert 'pred_instances' not in kie_data_sample | |