File size: 2,310 Bytes
d7a991a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv import Config

from mmpose.datasets.builder import build_dataset


def test_concat_dataset():
    """Check that a ConcatDataset can be built in each supported way.

    Each case concatenates the same COCO test dataset with itself, so the
    result is expected to contain exactly twice as many samples as a single
    copy. The three construction paths are:

    1. an explicit ``type='ConcatDataset'`` config,
    2. a plain list of dataset configs,
    3. a single config whose ``ann_file`` (and related per-dataset fields)
       are sequences.
    """
    # build COCO-like dataset config
    dataset_info = Config.fromfile(
        'configs/_base_/datasets/coco.py').dataset_info

    channel_cfg = dict(
        num_output_channels=17,
        dataset_joints=17,
        dataset_channel=[
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
        ],
        inference_channel=[
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
        ])

    data_cfg = dict(
        image_size=[192, 256],
        heatmap_size=[48, 64],
        num_output_channels=channel_cfg['num_output_channels'],
        num_joints=channel_cfg['dataset_joints'],
        dataset_channel=channel_cfg['dataset_channel'],
        inference_channel=channel_cfg['inference_channel'],
        soft_nms=False,
        nms_thr=1.0,
        oks_thr=0.9,
        vis_thr=0.2,
        use_gt_bbox=True,
        det_bbox_thr=0.0,
        bbox_file='tests/data/coco/test_coco_det_AP_H_56.json',
    )

    dataset_cfg = dict(
        type='TopDownCocoDataset',
        ann_file='tests/data/coco/test_coco.json',
        img_prefix='tests/data/coco/',
        data_cfg=data_cfg,
        pipeline=[],
        dataset_info=dataset_info)

    # Reference single dataset; every concatenation below should be 2x it.
    dataset = build_dataset(dataset_cfg)

    # Case 1: build ConcatDataset explicitly
    concat_dataset_cfg = dict(
        type='ConcatDataset', datasets=[dataset_cfg, dataset_cfg])
    concat_dataset = build_dataset(concat_dataset_cfg)
    assert len(concat_dataset) == 2 * len(dataset)

    # Case 2: build ConcatDataset from cfg sequence
    concat_dataset = build_dataset([dataset_cfg, dataset_cfg])
    assert len(concat_dataset) == 2 * len(dataset)

    # Case 3: build ConcatDataset from ann_file sequence.
    # ``dict.copy()`` is shallow, so the nested ``data_cfg`` must be copied
    # too; otherwise wrapping its values in lists below would also rewrite
    # the ``data_cfg`` shared with ``dataset_cfg`` and the earlier cases.
    concat_dataset_cfg = dataset_cfg.copy()
    concat_dataset_cfg['data_cfg'] = dataset_cfg['data_cfg'].copy()
    # Top-level per-dataset fields become one entry per concatenated dataset.
    for key in ['ann_file', 'type', 'img_prefix', 'dataset_info']:
        val = concat_dataset_cfg[key]
        concat_dataset_cfg[key] = [val] * 2
    # Per-dataset fields nested inside ``data_cfg`` are expanded likewise.
    for key in ['num_joints', 'dataset_channel']:
        val = concat_dataset_cfg['data_cfg'][key]
        concat_dataset_cfg['data_cfg'][key] = [val] * 2
    concat_dataset = build_dataset(concat_dataset_cfg)
    assert len(concat_dataset) == 2 * len(dataset)

    # The source config must not have been modified by building Case 3.
    assert dataset_cfg['data_cfg']['num_joints'] == \
        channel_cfg['dataset_joints']
    assert dataset_cfg['data_cfg']['dataset_channel'] == \
        channel_cfg['dataset_channel']