Spaces:
Runtime error
Runtime error
File size: 2,275 Bytes
2366e36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import unittest.mock as mock
import numpy as np
import pytest
from mmocr.datasets.pipelines import (OneOfWrapper, RandomWrapper,
TorchVisionWrapper)
from mmocr.datasets.pipelines.transforms import ColorJitter
def test_torchvision_wrapper():
    """Check TorchVisionWrapper construction errors and a Grayscale pass."""
    sample = {'img': np.ones((128, 100, 3), dtype=np.uint8)}

    # An op name that cannot be resolved is expected to fail at construction.
    with pytest.raises(Exception):
        TorchVisionWrapper(op='NonExist')

    # Constructing the wrapper without any argument is a TypeError.
    with pytest.raises(TypeError):
        TorchVisionWrapper()

    to_gray = TorchVisionWrapper('Grayscale')

    # An empty results dict is expected to trip an assertion in the wrapper.
    with pytest.raises(AssertionError):
        to_gray({})

    output = to_gray(sample)
    # The 3-channel input is expected to come back as a 2-D image, with
    # 'img_shape' reporting the same shape.
    gray_shape = (128, 100)
    assert output['img'].shape == gray_shape
    assert output['img_shape'] == gray_shape
@mock.patch('random.choice')
def test_oneof(rand_choice):
    """Check that OneOfWrapper runs the transform random.choice selects.

    ``random.choice`` is patched so the selected transform is deterministic.

    Args:
        rand_choice (mock.MagicMock): Stand-in for ``random.choice``,
            injected by ``mock.patch``.
    """
    color_jitter = dict(type='TorchVisionWrapper', op='ColorJitter')
    gray_scale = dict(type='TorchVisionWrapper', op='Grayscale')
    img = np.random.randint(0, 256, size=(128, 100, 3), dtype=np.uint8)

    f = OneOfWrapper([color_jitter, gray_scale])

    # Every call below receives its own copy of the image, so no assertion
    # depends on what an earlier call may have done to its input dict.

    # Select the first transform (color_jitter): channels are kept.
    rand_choice.side_effect = lambda candidates: candidates[0]
    results = f({'img': copy.deepcopy(img)})
    assert results['img'].shape == (128, 100, 3)

    # Select the second transform (gray_scale): a 2-D image is expected.
    rand_choice.side_effect = lambda candidates: candidates[1]
    results = f({'img': copy.deepcopy(img)})
    assert results['img'].shape == (128, 100)

    # Passing an already-built transform object alongside a config dict.
    f = OneOfWrapper([ColorJitter(), gray_scale])
    # The mock still selects index 1, so gray_scale is the one applied here.
    results = f({'img': copy.deepcopy(img)})
    assert results['img'].shape == (128, 100)

    # Invalid inputs are expected to be rejected by an assertion.
    for invalid_transforms in (None, [], {}):
        with pytest.raises(AssertionError):
            OneOfWrapper(invalid_transforms)
@mock.patch('numpy.random.uniform')
def test_runwithprob(np_random_uniform):
    """Check RandomWrapper against two mocked numpy.random.uniform draws.

    The draws are fixed at 0.1 and then 0.9 with a wrapper probability of
    0.5. The first call is expected to return a grayscale (2-D) image and
    the second to return the image with its three channels intact.

    Args:
        np_random_uniform (mock.MagicMock): Stand-in for
            ``numpy.random.uniform``, injected by ``mock.patch``.
    """
    np_random_uniform.side_effect = [0.1, 0.9]
    grayscale_cfg = dict(type='TorchVisionWrapper', op='Grayscale')
    maybe_gray = RandomWrapper([grayscale_cfg], 0.5)
    image = np.random.randint(0, 256, size=(128, 100, 3), dtype=np.uint8)

    # One expected output shape per mocked draw, in call order.
    for expected_shape in ((128, 100), (128, 100, 3)):
        output = maybe_gray({'img': copy.deepcopy(image)})
        assert output['img'].shape == expected_shape
|