| | |
| | import os.path as osp |
| | from copy import deepcopy |
| |
|
| | import mmcv |
| | import numpy as np |
| | import torch |
| |
|
| | from mmdet.utils import split_batch |
| |
|
| |
|
def test_split_batch():
    """Check the groups that ``split_batch`` returns for a tagged batch.

    Nine copies of one image are batched together: the first is tagged
    ``sup`` and the remaining eight alternate between ``unsup_teacher`` and
    ``unsup_student``. The test verifies the set of returned group names,
    the shape of each group's ``img`` tensor and the ``scale_factor`` stored
    in the metas of every group.
    """
    # (tag, scale factor value, number of samples carrying that tag)
    expected_groups = (
        ('sup', 0.5, 1),
        ('unsup_teacher', 1.0, 4),
        ('unsup_student', 2.0, 4),
    )
    scale_of = {name: value for name, value, _ in expected_groups}
    # One supervised sample followed by alternating teacher/student samples.
    tags = ['sup'] + ['unsup_teacher', 'unsup_student'] * 4

    img_path = osp.join(osp.dirname(__file__), '../data/color.jpg')
    image = mmcv.imread(img_path, 'color')
    h, w, _ = image.shape

    # Two boxes placed relative to the image size, both labelled 1.
    boxes = np.array(
        [[0.2 * w, 0.2 * h, 0.4 * w, 0.4 * h],
         [0.6 * w, 0.6 * h, 0.8 * w, 0.8 * h]],
        dtype=np.float32)
    labels = np.ones(boxes.shape[0], dtype=np.int64)

    # Move the channel axis to the front before batching.
    chw = torch.tensor(image).permute(2, 0, 1)
    base_meta = {
        'filename': img_path,
        'ori_shape': chw.shape,
        'img_shape': chw.shape,
        'img_norm_cfg': {
            'mean': np.array([103.53, 116.28, 123.675], dtype=np.float32),
            'std': np.array([1., 1., 1.], dtype=np.float32),
            'to_rgb': False
        },
        'pad_shape': chw.shape,
    }

    imgs = chw.unsqueeze(0).repeat(len(tags), 1, 1, 1)
    # Every sample gets its own deep copy of the shared meta, extended with
    # the scale factor and tag that belong to its group.
    img_metas = [
        dict(deepcopy(base_meta), scale_factor=[scale_of[tag]] * 4, tag=tag)
        for tag in tags
    ]

    # Only the first (supervised) sample carries annotations; the other
    # entries are empty placeholders.
    # NOTE(review): the 'gt_lables' key looks like a misspelling of
    # 'gt_labels'; it is kept unchanged here -- confirm before renaming.
    num_unlabeled = len(tags) - 1
    kwargs = {
        'gt_bboxes':
        [torch.tensor(boxes)] + [torch.zeros(0, 4)] * num_unlabeled,
        'gt_lables':
        [torch.tensor(labels)] + [torch.zeros(0)] * num_unlabeled,
    }

    data_groups = split_batch(imgs, img_metas, kwargs)

    assert set(data_groups.keys()) == set(tags)
    for name, _, count in expected_groups:
        assert data_groups[name]['img'].shape == (count, 3, h, w)
    for name, scale, count in expected_groups:
        group_metas = data_groups[name]['img_metas']
        for idx in range(count):
            assert group_metas[idx]['scale_factor'] == [scale] * 4
| |
|