File size: 6,019 Bytes
b13b124 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import glob
import os
from os.path import dirname, exists, isdir, join, relpath
from mmcv import Config
from torch import nn
from mmseg.models import build_segmentor
def _get_config_directory():
    """Locate the directory holding the predefined segmentor configs.

    The directory is looked up as ``<root>/configs`` where ``<root>`` is two
    ``dirname`` steps above this file, or above ``mmseg.__file__`` when
    ``__file__`` is not defined.

    Returns:
        str: Path of the ``configs`` directory.

    Raises:
        Exception: If the computed path does not exist.
    """
    try:
        # Normal case: executed from a file inside the source repository.
        anchor_fpath = __file__
    except NameError:
        # Interactive sessions (e.g. IPython) do not define ``__file__``;
        # fall back to the location of the installed ``mmseg`` package.
        import mmseg
        anchor_fpath = mmseg.__file__

    configs_root = join(dirname(dirname(anchor_fpath)), 'configs')
    if exists(configs_root):
        return configs_root
    raise Exception('Cannot find config path')
def test_config_build_segmentor():
    """Test that all segmentation models defined in the configs can be
    initialized.

    One config file is taken from every immediate sub-folder of the config
    directory (paths containing ``_base_`` are skipped). For each of them the
    segmentor is built with ``pretrained`` reset to ``None`` and its decode
    head is compared against the config by ``_check_decode_head``.
    """
    config_dpath = _get_config_directory()
    print('Found config_dpath = {!r}'.format(config_dpath))

    config_fpaths = []
    # one config each sub folder
    for sub_folder in sorted(os.listdir(config_dpath)):
        # ``os.listdir`` yields bare entry names, so the name has to be
        # joined with ``config_dpath`` before testing it; otherwise the check
        # is resolved against the current working directory.
        sub_dpath = join(config_dpath, sub_folder)
        if not isdir(sub_dpath):
            continue
        # Sorted so that the selected config is the same on every run.
        candidates = sorted(glob.glob(join(sub_dpath, '*.py')))
        if candidates:
            # A folder without any ``*.py`` file is skipped, not an error.
            config_fpaths.append(candidates[0])
    config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
    config_names = [relpath(p, config_dpath) for p in config_fpaths]

    print('Using {} config files'.format(len(config_names)))

    for config_fname in config_names:
        config_fpath = join(config_dpath, config_fname)
        config_mod = Config.fromfile(config_fpath)

        print('Building segmentor, config_fpath = {!r}'.format(config_fpath))

        # Remove pretrained keys to allow for testing in an offline environment
        if 'pretrained' in config_mod.model:
            config_mod.model['pretrained'] = None

        print('building {}'.format(config_fname))
        segmentor = build_segmentor(config_mod.model)
        assert segmentor is not None

        head_config = config_mod.model['decode_head']
        _check_decode_head(head_config, segmentor.decode_head)
def test_config_data_pipeline():
    """Test whether the data pipeline is valid and can process corner cases.

    For every config file one directory level below the config directory
    (paths containing ``_base_`` are skipped), the train and test pipelines
    are composed without their leading loading steps and run on a random
    1024x2048 image (plus a random segmentation map for training).

    CommandLine:
        xdoctest -m tests/test_config.py test_config_data_pipeline
    """
    from mmseg.datasets.pipelines import Compose
    import numpy as np

    config_dpath = _get_config_directory()
    print('Found config_dpath = {!r}'.format(config_dpath))

    # NOTE: without ``recursive=True`` the ``**`` component behaves like a
    # single ``*``, i.e. only files exactly one folder deep are matched.
    config_fpaths = glob.glob(join(config_dpath, '**', '*.py'))
    config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
    config_names = [relpath(p, config_dpath) for p in config_fpaths]

    print('Using {} config files'.format(len(config_names)))

    for config_fname in config_names:
        config_fpath = join(config_dpath, config_fname)
        print(
            'Building data pipeline, config_fpath = {!r}'.format(config_fpath))
        config_mod = Config.fromfile(config_fpath)

        # remove loading pipeline
        # The first train step is dropped but its ``to_float32`` flag is kept
        # so the synthetic image can be given the matching dtype.
        load_img_pipeline = config_mod.train_pipeline.pop(0)
        to_float32 = load_img_pipeline.get('to_float32', False)
        # Drop the next train step too (presumably the annotation loader --
        # verify against the configs) and the first test step.
        config_mod.train_pipeline.pop(0)
        config_mod.test_pipeline.pop(0)

        train_pipeline = Compose(config_mod.train_pipeline)
        test_pipeline = Compose(config_mod.test_pipeline)

        img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8)
        if to_float32:
            img = img.astype(np.float32)
        seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8)

        results = dict(
            filename='test_img.png',
            ori_filename='test_img.png',
            img=img,
            img_shape=img.shape,
            ori_shape=img.shape,
            gt_semantic_seg=seg)
        results['seg_fields'] = ['gt_semantic_seg']

        print('Test training data pipeline: \n{!r}'.format(train_pipeline))
        output_results = train_pipeline(results)
        assert output_results is not None

        # The test pipeline receives the image only, without ground truth.
        results = dict(
            filename='test_img.png',
            ori_filename='test_img.png',
            img=img,
            img_shape=img.shape,
            ori_shape=img.shape,
        )
        print('Test testing data pipeline: \n{!r}'.format(test_pipeline))
        output_results = test_pipeline(results)
        assert output_results is not None
def _check_decode_head(decode_head_cfg, decode_head):
    """Check that a built decode head is consistent with its config.

    Args:
        decode_head_cfg: Config of the decode head, or a list of such
            configs. A single config is read both by key (``cfg['type']``)
            and by attribute (``cfg.in_channels``, ``cfg.channels``,
            ``cfg.num_classes``).
        decode_head: The built head, or an ``nn.ModuleList`` of heads when
            ``decode_head_cfg`` is a list.

    Raises:
        AssertionError: If any of the checked properties disagree.
    """
    if isinstance(decode_head_cfg, list):
        # Multiple heads: validate each config/head pair recursively.
        assert isinstance(decode_head, nn.ModuleList)
        assert len(decode_head_cfg) == len(decode_head)
        for sub_cfg, sub_head in zip(decode_head_cfg, decode_head):
            _check_decode_head(sub_cfg, sub_head)
        return

    # check consistency between decode_head_cfg and decode_head
    assert decode_head_cfg['type'] == decode_head.__class__.__name__

    in_channels = decode_head_cfg.in_channels
    input_transform = decode_head.input_transform
    assert input_transform in ['resize_concat', 'multiple_select', None]
    if input_transform is not None:
        # Multi-input heads: one channel count per selected input index.
        assert isinstance(in_channels, (list, tuple))
        assert isinstance(decode_head.in_index, (list, tuple))
        assert len(in_channels) == len(decode_head.in_index)
        if input_transform == 'resize_concat':
            # This check must live inside the "not None" branch; as a
            # sibling ``elif`` it could never be reached.
            assert sum(in_channels) == decode_head.in_channels
    else:
        # Single-input head: scalar channel count and scalar index.
        assert isinstance(in_channels, int)
        assert in_channels == decode_head.in_channels
        assert isinstance(decode_head.in_index, int)

    if decode_head_cfg['type'] == 'PointHead':
        # PointHead's final layer is ``fc_seg`` and its input width is
        # ``channels + num_classes``.
        assert decode_head_cfg.channels+decode_head_cfg.num_classes == \
            decode_head.fc_seg.in_channels
        assert decode_head.fc_seg.out_channels == decode_head_cfg.num_classes
    else:
        assert decode_head_cfg.channels == decode_head.conv_seg.in_channels
        assert decode_head.conv_seg.out_channels == decode_head_cfg.num_classes
|