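# NOTE(review): the eight entries below appear to be page chrome picked up
# when this file was copied from a web file viewer; they are not part of the
# test module and are kept only as comments.
#   Image Segmentation
#   Inference Endpoints
#   test2 / tests / test_models /
#   mccaly's picture
#   Upload 660 files
#   history blame
#   No virus
#   7.43 kB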
"""pytest tests/"""
import copy
from os.path import dirname, exists, join
from unittest.mock import patch
import numpy as np
import pytest
import torch
import torch.nn as nn
from mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm
def _demo_mm_inputs(input_shape=(2, 3, 8, 16), num_classes=10):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple):
            input batch dimensions, unpacked as (N, C, H, W)
        num_classes (int):
            number of semantic classes

    Returns:
        dict: ``imgs`` (float tensor of shape ``input_shape``),
            ``img_metas`` (list of N meta dicts) and ``gt_semantic_seg``
            (long tensor of shape (N, 1, H, W) holding labels in
            ``[0, num_classes)``).
    """
    (N, C, H, W) = input_shape

    # Fixed seed: every call with the same arguments yields the same batch.
    rng = np.random.RandomState(0)

    imgs = rng.rand(*input_shape)
    # ``high`` is exclusive in ``randint``, so pass ``num_classes`` itself:
    # ``num_classes - 1`` never produced the last label and raised for
    # ``num_classes == 1``.
    segs = rng.randint(
        low=0, high=num_classes, size=(N, 1, H, W)).astype(np.uint8)

    img_metas = [{
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': 1.0,
        'flip': False,
        'flip_direction': 'horizontal'
    } for _ in range(N)]

    mm_inputs = {
        'imgs': torch.FloatTensor(imgs),
        'img_metas': img_metas,
        'gt_semantic_seg': torch.LongTensor(segs)
    }
    return mm_inputs
def _get_config_directory():
    """Find the predefined segmentor config directory.

    Returns:
        str: path of the ``configs`` directory under the repository root.

    Raises:
        Exception: if the ``configs`` directory does not exist.
    """
    try:
        # Assume we are running in the source mmsegmentation repo
        repo_dpath = dirname(dirname(dirname(__file__)))
    except NameError:
        # For IPython development when this __file__ is not defined
        import mmseg

        # ``mmseg.__file__`` is ``<repo>/mmseg/__init__.py``, so the repo
        # root is two levels up, not three (assumes the package sits at the
        # repo root, next to ``configs``).
        repo_dpath = dirname(dirname(mmseg.__file__))
    config_dpath = join(repo_dpath, 'configs')
    if not exists(config_dpath):
        raise Exception('Cannot find config path')
    return config_dpath
def _get_config_module(fname):
    """Load a configuration as a python module.

    Args:
        fname (str): config file path, relative to the directory returned by
            ``_get_config_directory``.

    Returns:
        The object produced by ``mmcv.Config.fromfile`` for that file.
    """
    # Imported lazily so mmcv is only required when a config is loaded.
    from mmcv import Config

    config_fpath = join(_get_config_directory(), fname)
    return Config.fromfile(config_fpath)
def _get_segmentor_cfg(fname):
    """Grab configs necessary to create a segmentor.

    These are deep copied to allow for safe modification of parameters without
    influencing other tests.

    Args:
        fname (str): config file name, forwarded to ``_get_config_module``.

    Returns:
        A deep copy of the ``model`` attribute of the loaded config.
    """
    config = _get_config_module(fname)
    return copy.deepcopy(config.model)
# NOTE(review): the six test functions below are declared without bodies in
# this copy of the file, so the module cannot be imported as-is. Each one is
# presumably meant to call `_test_encoder_decoder_forward` with the config
# file of the model named in the test -- TODO restore the bodies from the
# upstream repository instead of guessing the config paths.
def test_pspnet_forward():
def test_fcn_forward():
def test_deeplabv3_forward():
def test_deeplabv3plus_forward():
def test_gcnet_forward():
def test_ann_forward():
def test_ccnet_forward():
    """Forward-pass test for CCNet; skipped when CUDA is unavailable."""
    if not torch.cuda.is_available():
        pytest.skip('CCNet requires CUDA')
    # TODO(review): nothing follows the CUDA guard in this copy of the file.
    # Presumably a call to `_test_encoder_decoder_forward` with the CCNet
    # config belongs here -- restore it from the upstream repository.
# NOTE(review): the twelve test functions below are declared without bodies
# in this copy of the file, so the module cannot be imported as-is. Each one
# is presumably meant to call `_test_encoder_decoder_forward` with the config
# file of the model named in the test -- TODO restore the bodies from the
# upstream repository instead of guessing the config paths.
def test_danet_forward():
def test_nonlocal_net_forward():
def test_upernet_forward():
def test_hrnet_forward():
def test_ocrnet_forward():
def test_psanet_forward():
def test_encnet_forward():
def test_sem_fpn_forward():
def test_point_rend_forward():
def test_mobilenet_v2_forward():
def test_dnlnet_forward():
def test_emanet_forward():
def get_world_size(process_group):
    """Stand-in for ``torch.distributed.get_world_size`` that reports 1.

    Patched over the real function while the forward test runs, so no
    distributed process group has to be initialised.

    Args:
        process_group: ignored.

    Returns:
        int: always 1.
    """
    return 1
def _check_input_dim(self, inputs):
    """No-op replacement for a batch-norm input-dimension check.

    Takes the signature of a bound method (``self`` plus the input tensor)
    and accepts every input without raising.

    NOTE(review): the body is missing in this copy of the file; a no-op is
    assumed because the function is described as a "hack" of the dim check
    in BN -- confirm against the upstream repository.
    """
def _convert_batchnorm(module):
    """Recursively replace every ``SyncBatchNorm`` in ``module`` by ``_BatchNorm``.

    Args:
        module (nn.Module): module tree to convert. Children are replaced in
            place via ``add_module``.

    Returns:
        nn.Module: ``module`` itself, or a new ``_BatchNorm`` carrying the
            same parameters and running statistics when ``module`` is a
            ``SyncBatchNorm``.
    """
    module_output = module
    if isinstance(module, SyncBatchNorm):
        # to be consistent with SyncBN, we hack dim check function in BN
        module_output = _BatchNorm(module.num_features, module.eps,
                                   module.momentum, module.affine,
                                   module.track_running_stats)
        if module.affine:
            # Copy the learned scale/shift so the replacement computes the
            # same affine transform as the layer it replaces.
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        # Running statistics are shared with the original layer, not copied.
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, _convert_batchnorm(child))
    return module_output
@patch('torch.nn.modules.batchnorm._BatchNorm._check_input_dim',
       _check_input_dim)
@patch('torch.distributed.get_world_size', get_world_size)
def _test_encoder_decoder_forward(cfg_file):
    """Build the segmentor described by ``cfg_file`` and run both forward modes.

    The model is built without pretrained weights and with whole-image test
    mode, fed a small random batch, and checked to return a dict of losses in
    training mode; the inference path is then executed under ``no_grad``.

    Args:
        cfg_file (str): config file name, forwarded to ``_get_segmentor_cfg``.
    """
    model = _get_segmentor_cfg(cfg_file)
    model['pretrained'] = None
    model['test_cfg']['mode'] = 'whole'

    from mmseg.models import build_segmentor
    segmentor = build_segmentor(model)

    # With several decode heads, the last one provides the class count.
    if isinstance(segmentor.decode_head, nn.ModuleList):
        num_classes = segmentor.decode_head[-1].num_classes
    else:
        num_classes = segmentor.decode_head.num_classes

    # batch_size=2 for BatchNorm
    input_shape = (2, 3, 32, 32)
    mm_inputs = _demo_mm_inputs(input_shape, num_classes=num_classes)

    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    gt_semantic_seg = mm_inputs['gt_semantic_seg']

    # convert to cuda Tensor if applicable
    if torch.cuda.is_available():
        segmentor = segmentor.cuda()
        imgs = imgs.cuda()
        gt_semantic_seg = gt_semantic_seg.cuda()
    else:
        # On CPU, swap SyncBatchNorm layers for plain batch norm
        # (presumably because SyncBatchNorm needs GPU inputs -- confirm).
        segmentor = _convert_batchnorm(segmentor)

    # Test forward train
    losses = segmentor.forward(
        imgs, img_metas, gt_semantic_seg=gt_semantic_seg, return_loss=True)
    assert isinstance(losses, dict)

    # Test forward test
    with torch.no_grad():
        segmentor.eval()
        # pack into lists
        img_list = [img[None, :] for img in imgs]
        img_meta_list = [[img_meta] for img_meta in img_metas]
        segmentor.forward(img_list, img_meta_list, return_loss=False)