|
|
import pytest |
|
|
import torch |
|
|
|
|
|
from kornia.augmentation.random_generator import ( |
|
|
center_crop_generator3d, |
|
|
random_affine_generator3d, |
|
|
random_crop_generator3d, |
|
|
random_motion_blur_generator3d, |
|
|
random_perspective_generator3d, |
|
|
random_rotation_generator3d, |
|
|
) |
|
|
from kornia.testing import assert_close |
|
|
|
|
|
|
|
|
class RandomGeneratorBaseTests:
    """Template of the test cases that each generator test class overrides.

    Every method here raises ``NotImplementedError`` until a subclass
    replaces it with a concrete test.
    """

    def test_valid_param_combinations(self, device, dtype):
        """Placeholder for the accepted-parameter checks of a subclass."""
        raise NotImplementedError()

    def test_invalid_param_combinations(self, device, dtype):
        """Placeholder for the rejected-parameter checks of a subclass."""
        raise NotImplementedError()

    def test_random_gen(self, device, dtype):
        """Placeholder for the seeded-output check of a subclass."""
        raise NotImplementedError()

    def test_same_on_batch(self, device, dtype):
        """Placeholder for the ``same_on_batch`` check of a subclass."""
        raise NotImplementedError()
|
|
|
|
|
|
|
|
class TestRandomPerspectiveGen3D(RandomGeneratorBaseTests):
    """Tests for ``random_perspective_generator3d``."""

    @pytest.mark.parametrize('batch_size', [0, 1, 8])
    @pytest.mark.parametrize('depth,height,width', [(200, 200, 200)])
    @pytest.mark.parametrize('distortion_scale', [torch.tensor(0.0), torch.tensor(0.5), torch.tensor(1.0)])
    @pytest.mark.parametrize('same_on_batch', [True, False])
    def test_valid_param_combinations(
        self, depth, height, width, distortion_scale, batch_size, same_on_batch, device, dtype
    ):
        """Smoke test: each parametrized combination must be accepted without raising."""
        random_perspective_generator3d(
            batch_size=batch_size,
            depth=depth,
            height=height,
            width=width,
            distortion_scale=distortion_scale.to(device=device, dtype=dtype),
            same_on_batch=same_on_batch,
        )

    @pytest.mark.parametrize(
        'depth,height,width,distortion_scale',
        [
            (-100, 100, 100, torch.tensor(0.5)),
            (100, -100, 100, torch.tensor(0.5)),
            (100, 100, -100, torch.tensor(-0.5)),
            (100, 100, 100, torch.tensor(1.5)),
            (100, 100, 100, torch.tensor([0.0, 0.5])),
        ],
    )
    def test_invalid_param_combinations(self, depth, height, width, distortion_scale, device, dtype):
        """Each parametrized combination is expected to make the generator raise."""
        with pytest.raises(Exception):
            random_perspective_generator3d(
                batch_size=8,
                # `depth` is parametrized above and must be forwarded: without it the
                # call cannot reach the values under test, and the broad
                # `pytest.raises(Exception)` would presumably be satisfied by the
                # missing argument alone (assumes `depth` has no default - TODO confirm).
                depth=depth,
                height=height,
                width=width,
                distortion_scale=distortion_scale.to(device=device, dtype=dtype),
            )

    def test_random_gen(self, device, dtype):
        """With seed 42 the generator must reproduce the recorded start/end points."""
        torch.manual_seed(42)
        batch_size = 2
        res = random_perspective_generator3d(batch_size, 200, 200, 200, torch.tensor(0.5, device=device, dtype=dtype))

        # The eight recorded start points; identical for both batch items.
        corners = [
            [0.0, 0.0, 0.0],
            [199.0, 0.0, 0.0],
            [199.0, 199.0, 0.0],
            [0.0, 199.0, 0.0],
            [0.0, 0.0, 199.0],
            [199.0, 0.0, 199.0],
            [199.0, 199.0, 199.0],
            [0.0, 199.0, 199.0],
        ]
        # Recorded end points, one list of eight per batch item.
        ends_first = [
            [44.1135, 45.7502, 19.1432],
            [151.0347, 19.5224, 30.0448],
            [186.1714, 159.3179, 47.0386],
            [6.6593, 152.2701, 29.6790],
            [43.4702, 28.3858, 161.9453],
            [177.5298, 44.2721, 170.3048],
            [185.6710, 167.6275, 185.5184],
            [22.0682, 184.1540, 157.4157],
        ]
        ends_second = [
            [5.2657, 13.4747, 17.9406],
            [189.0318, 27.3596, 0.3080],
            [151.4223, 195.2367, 44.3007],
            [29.1605, 182.1176, 40.4487],
            [28.8963, 45.1991, 171.2670],
            [181.8843, 31.7171, 180.7795],
            [163.4786, 151.6794, 159.5485],
            [14.0707, 159.5684, 169.5268],
        ]
        expected = {
            'start_points': torch.tensor([corners, corners], device=device, dtype=dtype),
            'end_points': torch.tensor([ends_first, ends_second], device=device, dtype=dtype),
        }

        assert res.keys() == expected.keys()
        for key in ('start_points', 'end_points'):
            assert_close(res[key], expected[key], atol=1e-4, rtol=1e-4)

    def test_same_on_batch(self, device, dtype):
        """With ``same_on_batch=True`` both batch items must share the recorded points."""
        torch.manual_seed(42)
        batch_size = 2
        res = random_perspective_generator3d(
            batch_size, 200, 200, 200, torch.tensor(0.5, device=device, dtype=dtype), same_on_batch=True
        )

        # The eight recorded start points; identical for both batch items.
        corners = [
            [0.0, 0.0, 0.0],
            [199.0, 0.0, 0.0],
            [199.0, 199.0, 0.0],
            [0.0, 199.0, 0.0],
            [0.0, 0.0, 199.0],
            [199.0, 0.0, 199.0],
            [199.0, 199.0, 199.0],
            [0.0, 199.0, 199.0],
        ]
        # A single recorded set of end points, expected for both batch items.
        shared_ends = [
            [44.1135, 45.7502, 19.1432],
            [151.0347, 19.5224, 30.0448],
            [186.1714, 159.3179, 47.0386],
            [6.6593, 152.2701, 29.6790],
            [43.4702, 28.3858, 161.9453],
            [177.5298, 44.2721, 170.3048],
            [185.6710, 167.6275, 185.5184],
            [22.0682, 184.1540, 157.4157],
        ]
        expected = {
            'start_points': torch.tensor([corners, corners], device=device, dtype=dtype),
            'end_points': torch.tensor([shared_ends, shared_ends], device=device, dtype=dtype),
        }

        assert res.keys() == expected.keys()
        for key in ('start_points', 'end_points'):
            assert_close(res[key], expected[key], atol=1e-4, rtol=1e-4)
|
|
|
|
|
|
|
|
class TestRandomAffineGen3D(RandomGeneratorBaseTests):
    """Tests for ``random_affine_generator3d``."""

    @pytest.mark.parametrize('batch_size', [0, 1, 8])
    @pytest.mark.parametrize('depth,height,width', [(200, 300, 400)])
    @pytest.mark.parametrize('degrees', [torch.tensor([(0, 30), (0, 30), (0, 30)])])
    @pytest.mark.parametrize('translate', [None, torch.tensor([0.1, 0.1, 0.1])])
    @pytest.mark.parametrize('scale', [None, torch.tensor([[0.7, 1.2], [0.7, 1.2], [0.7, 1.2]])])
    @pytest.mark.parametrize('shear', [None, torch.tensor([[0, 20], [0, 20], [0, 20], [0, 20], [0, 20], [0, 20]])])
    @pytest.mark.parametrize('same_on_batch', [True, False])
    def test_valid_param_combinations(
        self, batch_size, depth, height, width, degrees, translate, scale, shear, same_on_batch, device, dtype
    ):
        """Smoke test: each parametrized combination must be accepted without raising."""
        # Optional tensors stay None; the others are moved to the device/dtype under test.
        translate_arg = None if translate is None else translate.to(device=device, dtype=dtype)
        scale_arg = None if scale is None else scale.to(device=device, dtype=dtype)
        shear_arg = None if shear is None else shear.to(device=device, dtype=dtype)
        random_affine_generator3d(
            batch_size=batch_size,
            depth=depth,
            height=height,
            width=width,
            degrees=degrees.to(device=device, dtype=dtype),
            translate=translate_arg,
            scale=scale_arg,
            shears=shear_arg,
            same_on_batch=same_on_batch,
        )

    @pytest.mark.parametrize(
        'depth,height,width,degrees,translate,scale,shear',
        [
            (-100, 100, 100, torch.tensor([[0, 9], [0, 9], [0, 9]]), None, None, None),
            (100, -100, 100, torch.tensor([[0, 9], [0, 9], [0, 9]]), None, None, None),
            (100, 100, -100, torch.tensor([[0, 9], [0, 9], [0, 9]]), None, None, None),
            (100, 100, 100, torch.tensor([0, 9]), None, None, None),
            (100, 100, 100, torch.tensor([[0, 9], [0, 9], [0, 9]]), torch.tensor([0.1, 0.2]), None, None),
            # NOTE(review): this case repeats the previous one verbatim - confirm what was intended.
            (100, 100, 100, torch.tensor([[0, 9], [0, 9], [0, 9]]), torch.tensor([0.1, 0.2]), None, None),
            (100, 100, 100, torch.tensor([[0, 9], [0, 9], [0, 9]]), torch.tensor([0.1]), None, None),
            (100, 100, 100, torch.tensor([[0, 9], [0, 9], [0, 9]]), None, torch.tensor([[0.2, 0.2, 0.2]]), None),
            (100, 100, 100, torch.tensor([[0, 9], [0, 9], [0, 9]]), None, torch.tensor([0.2]), None),
            (100, 100, 100, torch.tensor([[0, 9], [0, 9], [0, 9]]), None, None, torch.tensor([[20, 20, 30]])),
            (100, 100, 100, torch.tensor([[0, 9], [0, 9], [0, 9]]), None, None, torch.tensor([20])),
        ],
    )
    def test_invalid_param_combinations(self, depth, height, width, degrees, translate, scale, shear, device, dtype):
        """Each parametrized combination is expected to make the generator raise."""
        with pytest.raises(Exception):
            translate_arg = None if translate is None else translate.to(device=device, dtype=dtype)
            scale_arg = None if scale is None else scale.to(device=device, dtype=dtype)
            shear_arg = None if shear is None else shear.to(device=device, dtype=dtype)
            random_affine_generator3d(
                batch_size=8,
                depth=depth,
                height=height,
                width=width,
                degrees=degrees.to(device=device, dtype=dtype),
                translate=translate_arg,
                scale=scale_arg,
                shears=shear_arg,
            )

    def test_random_gen(self, device, dtype):
        """With seed 42 the generator must reproduce the recorded affine parameters."""
        torch.manual_seed(42)
        degrees = torch.tensor([[10, 20], [10, 20], [10, 20]])
        translate = torch.tensor([0.1, 0.1, 0.1])
        scale = torch.tensor([[0.7, 1.2], [0.7, 1.2], [0.7, 1.2]])
        shear = torch.tensor([[0, 20], [0, 20], [0, 20], [0, 20], [0, 20], [0, 20]])
        res = random_affine_generator3d(
            batch_size=2,
            depth=200,
            height=200,
            width=200,
            degrees=degrees.to(device=device, dtype=dtype),
            translate=translate.to(device=device, dtype=dtype),
            scale=scale.to(device=device, dtype=dtype),
            shears=shear.to(device=device, dtype=dtype),
        )

        # Recorded values, keyed exactly as the generator's output dict.
        recorded = {
            'translations': [[14.7762, 9.6438, 15.4177], [2.7086, -2.8238, 2.9562]],
            'center': [[99.5000, 99.5000, 99.5000], [99.5000, 99.5000, 99.5000]],
            'scale': [[0.8283, 1.1704, 1.1673], [1.0968, 0.7666, 0.9968]],
            'angles': [[18.8227, 13.8286, 13.9045], [19.1500, 19.5931, 16.0090]],
            'sxy': [5.3316, 12.5490],
            'sxz': [5.3926, 8.8273],
            'syx': [5.9384, 16.6337],
            'syz': [2.1063, 5.3899],
            'szx': [7.1763, 3.9873],
            'szy': [10.9438, 0.1232],
        }
        expected = {name: torch.tensor(values, device=device, dtype=dtype) for name, values in recorded.items()}

        assert res.keys() == expected.keys()
        for name, target in expected.items():
            assert_close(res[name], target, rtol=1e-4, atol=1e-4)

    def test_same_on_batch(self, device, dtype):
        """With ``same_on_batch=True`` both batch items must share the recorded parameters."""
        torch.manual_seed(42)
        degrees = torch.tensor([[10, 20], [10, 20], [10, 20]])
        translate = torch.tensor([0.1, 0.1, 0.1])
        scale = torch.tensor([[0.7, 1.2], [0.7, 1.2], [0.7, 1.2]])
        shear = torch.tensor([[0, 20], [0, 20], [0, 20], [0, 20], [0, 20], [0, 20]])
        res = random_affine_generator3d(
            batch_size=2,
            depth=200,
            height=200,
            width=200,
            degrees=degrees.to(device=device, dtype=dtype),
            translate=translate.to(device=device, dtype=dtype),
            scale=scale.to(device=device, dtype=dtype),
            shears=shear.to(device=device, dtype=dtype),
            same_on_batch=True,
        )

        # Recorded values, keyed exactly as the generator's output dict.
        recorded = {
            'translations': [[-9.7371, 11.7457, 17.6309], [-9.7371, 11.7457, 17.6309]],
            'center': [[99.5000, 99.5000, 99.5000], [99.5000, 99.5000, 99.5000]],
            'scale': [[1.1797, 0.8952, 1.0004], [1.1797, 0.8952, 1.0004]],
            'angles': [[18.8227, 19.1500, 13.8286], [18.8227, 19.1500, 13.8286]],
            'sxy': [2.6637, 2.6637],
            'sxz': [18.6920, 18.6920],
            'syx': [11.8716, 11.8716],
            'syz': [17.3881, 17.3881],
            'szx': [11.3543, 11.3543],
            'szy': [14.8219, 14.8219],
        }
        expected = {name: torch.tensor(values, device=device, dtype=dtype) for name, values in recorded.items()}

        assert res.keys() == expected.keys()
        for name, target in expected.items():
            assert_close(res[name], target, rtol=1e-4, atol=1e-4)
|
|
|
|
|
|
|
|
class TestRandomRotationGen3D(RandomGeneratorBaseTests):
    """Tests for ``random_rotation_generator3d``."""

    @pytest.mark.parametrize('batch_size', [0, 1, 8])
    @pytest.mark.parametrize('degrees', [torch.tensor([[0, 30], [0, 30], [0, 30]])])
    @pytest.mark.parametrize('same_on_batch', [True, False])
    def test_valid_param_combinations(self, batch_size, degrees, same_on_batch, device, dtype):
        """Smoke test: each parametrized combination must be accepted without raising."""
        degrees_arg = degrees.to(device=device, dtype=dtype)
        random_rotation_generator3d(batch_size=batch_size, degrees=degrees_arg, same_on_batch=same_on_batch)

    @pytest.mark.parametrize(
        'degrees',
        [torch.tensor(10), torch.tensor([10]), torch.tensor([[0, 30]]), torch.tensor([[0, 30], [0, 30]])],
    )
    def test_invalid_param_combinations(self, degrees, device, dtype):
        """Each parametrized ``degrees`` tensor is expected to make the generator raise."""
        with pytest.raises(Exception):
            random_rotation_generator3d(batch_size=8, degrees=degrees.to(device=device, dtype=dtype))

    def test_random_gen(self, device, dtype):
        """With seed 42 the generator must reproduce the recorded yaw/pitch/roll values."""
        torch.manual_seed(42)
        degrees = torch.tensor([[0, 30], [0, 30], [0, 30]]).to(device=device, dtype=dtype)
        res = random_rotation_generator3d(batch_size=2, degrees=degrees, same_on_batch=False)

        recorded = {
            'yaw': [26.4681, 27.4501],
            'pitch': [11.4859, 28.7792],
            'roll': [11.7134, 18.0269],
        }
        expected = {axis: torch.tensor(values, device=device, dtype=dtype) for axis, values in recorded.items()}

        assert res.keys() == expected.keys()
        for axis, target in expected.items():
            assert_close(res[axis], target, atol=1e-4, rtol=1e-4)

    def test_same_on_batch(self, device, dtype):
        """With ``same_on_batch=True`` both batch items must share the recorded angles."""
        torch.manual_seed(42)
        degrees = torch.tensor([[0, 30], [0, 30], [0, 30]]).to(device=device, dtype=dtype)
        res = random_rotation_generator3d(batch_size=2, degrees=degrees, same_on_batch=True)

        recorded = {
            'yaw': [26.4681, 26.4681],
            'pitch': [27.4501, 27.4501],
            'roll': [11.4859, 11.4859],
        }
        expected = {axis: torch.tensor(values, device=device, dtype=dtype) for axis, values in recorded.items()}

        assert res.keys() == expected.keys()
        for axis, target in expected.items():
            assert_close(res[axis], target, atol=1e-4, rtol=1e-4)
|
|
|
|
|
|
|
|
class TestRandomCropGen3D(RandomGeneratorBaseTests):
    """Tests for ``random_crop_generator3d``."""

    @pytest.mark.parametrize('batch_size', [0, 2])
    @pytest.mark.parametrize('input_size', [(200, 200, 200)])
    @pytest.mark.parametrize('size', [(100, 100, 100), torch.tensor([50, 60, 70])])
    @pytest.mark.parametrize('resize_to', [None, (100, 100, 100)])
    @pytest.mark.parametrize('same_on_batch', [True, False])
    def test_valid_param_combinations(self, batch_size, input_size, size, resize_to, same_on_batch, device, dtype):
        """Smoke test: each parametrized combination must be accepted without raising."""
        size_arg = size
        if isinstance(size, torch.Tensor):
            # Tensor sizes get one row per batch item and are moved to the device/dtype under test.
            size_arg = size.repeat(batch_size, 1).to(device=device, dtype=dtype)
        random_crop_generator3d(
            batch_size=batch_size,
            input_size=input_size,
            size=size_arg,
            resize_to=resize_to,
            same_on_batch=same_on_batch,
        )

    @pytest.mark.parametrize(
        'input_size,size,resize_to',
        [
            ((-300, 300, 300), (200, 200, 200), (100, 100, 100)),
            ((100, 100, 100), (200, 200, 200), (100, 100, 100)),
            ((200, 200, 200), torch.tensor([50, 50, 50]), (100, 100, 100)),
            ((100, 100, 100), torch.tensor([[50, 60, 70], [50, 60, 70]]), (100, 100)),
        ],
    )
    def test_invalid_param_combinations(self, input_size, size, resize_to, device, dtype):
        """Each parametrized combination is expected to make the generator raise."""
        with pytest.raises(Exception):
            size_arg = size.to(device=device, dtype=dtype) if isinstance(size, torch.Tensor) else size
            random_crop_generator3d(
                batch_size=2,
                input_size=input_size,
                size=size_arg,
                resize_to=resize_to,
            )

    def test_random_gen(self, device, dtype):
        """With seed 42 the generator must reproduce the recorded ``src``/``dst`` points."""
        torch.manual_seed(42)
        res = random_crop_generator3d(
            batch_size=2,
            input_size=(200, 200, 200),
            size=torch.tensor([[50, 60, 70], [50, 60, 70]], device=device, dtype=dtype),
            resize_to=(100, 100, 100),
        )

        # Recorded source points, one list of eight per batch item.
        src_first = [
            [115, 53, 58],
            [184, 53, 58],
            [184, 112, 58],
            [115, 112, 58],
            [115, 53, 107],
            [184, 53, 107],
            [184, 112, 107],
            [115, 112, 107],
        ]
        src_second = [
            [119, 135, 90],
            [188, 135, 90],
            [188, 194, 90],
            [119, 194, 90],
            [119, 135, 139],
            [188, 135, 139],
            [188, 194, 139],
            [119, 194, 139],
        ]
        # Recorded destination points; identical for both batch items.
        dst_points = [
            [0, 0, 0],
            [99, 0, 0],
            [99, 99, 0],
            [0, 99, 0],
            [0, 0, 99],
            [99, 0, 99],
            [99, 99, 99],
            [0, 99, 99],
        ]
        expected = {
            'src': torch.tensor([src_first, src_second], device=device, dtype=dtype),
            'dst': torch.tensor([dst_points, dst_points], device=device, dtype=dtype),
        }

        assert res.keys() == expected.keys()
        for key in ('src', 'dst'):
            assert_close(res[key], expected[key], atol=1e-4, rtol=1e-4)

    def test_same_on_batch(self, device, dtype):
        """With ``same_on_batch=True`` both batch items must share the recorded points."""
        torch.manual_seed(42)
        res = random_crop_generator3d(
            batch_size=2,
            input_size=(200, 200, 200),
            size=torch.tensor([[50, 60, 70], [50, 60, 70]], device=device, dtype=dtype),
            resize_to=(100, 100, 100),
            same_on_batch=True,
        )

        # A single recorded set of source points, expected for both batch items.
        src_points = [
            [115, 129, 57],
            [184, 129, 57],
            [184, 188, 57],
            [115, 188, 57],
            [115, 129, 106],
            [184, 129, 106],
            [184, 188, 106],
            [115, 188, 106],
        ]
        # Recorded destination points; identical for both batch items.
        dst_points = [
            [0, 0, 0],
            [99, 0, 0],
            [99, 99, 0],
            [0, 99, 0],
            [0, 0, 99],
            [99, 0, 99],
            [99, 99, 99],
            [0, 99, 99],
        ]
        expected = {
            'src': torch.tensor([src_points, src_points], device=device, dtype=dtype),
            'dst': torch.tensor([dst_points, dst_points], device=device, dtype=dtype),
        }

        assert res.keys() == expected.keys()
        for key in ('src', 'dst'):
            assert_close(res[key], expected[key], atol=1e-4, rtol=1e-4)
|
|
|
|
|
|
|
|
class TestCenterCropGen3D(RandomGeneratorBaseTests):
    """Tests for ``center_crop_generator3d``."""

    @pytest.mark.parametrize('batch_size', [0, 2])
    @pytest.mark.parametrize('depth,height,width', [(200, 200, 200)])
    @pytest.mark.parametrize('size', [(100, 100, 100)])
    def test_valid_param_combinations(self, batch_size, depth, height, width, size, device, dtype):
        """Smoke test: each parametrized combination must be accepted without raising."""
        center_crop_generator3d(batch_size=batch_size, depth=depth, height=height, width=width, size=size)

    @pytest.mark.parametrize(
        'depth,height,width,size',
        [
            (200, 200, -200, (100, 100, 100)),
            (200, -200, 200, (100, 100)),
            (200, 100, 100, (300, 120, 100)),
            (200, 150, 100, (120, 180, 100)),
            (200, 100, 150, (120, 80, 200)),
        ],
    )
    def test_invalid_param_combinations(self, depth, height, width, size, device, dtype):
        """Each parametrized combination is expected to make the generator raise."""
        with pytest.raises(Exception):
            center_crop_generator3d(batch_size=2, depth=depth, height=height, width=width, size=size)

    def test_random_gen(self, device, dtype):
        """The generator must return the recorded ``src``/``dst`` points for a 200^3 input."""
        torch.manual_seed(42)
        res = center_crop_generator3d(batch_size=2, depth=200, height=200, width=200, size=(120, 150, 100))

        # Recorded source points for a single batch item.
        src_points = [
            [50, 25, 40],
            [149, 25, 40],
            [149, 174, 40],
            [50, 174, 40],
            [50, 25, 159],
            [149, 25, 159],
            [149, 174, 159],
            [50, 174, 159],
        ]
        # Recorded destination points for a single batch item.
        dst_points = [
            [0, 0, 0],
            [99, 0, 0],
            [99, 149, 0],
            [0, 149, 0],
            [0, 0, 119],
            [99, 0, 119],
            [99, 149, 119],
            [0, 149, 119],
        ]
        # Both batch items are expected to be identical, so the single item is tiled twice.
        src_expected = torch.tensor([src_points], device=device, dtype=torch.long).repeat(2, 1, 1)
        dst_expected = torch.tensor([dst_points], device=device, dtype=torch.long).repeat(2, 1, 1)
        expected = {'src': src_expected, 'dst': dst_expected}

        assert res.keys() == expected.keys()
        for key in ('src', 'dst'):
            # The result is moved to the device under test before comparing.
            assert_close(res[key].to(device=device), expected[key], atol=1e-4, rtol=1e-4)

    def test_same_on_batch(self, device, dtype):
        """Intentionally empty: overrides the base placeholder so it does not raise."""
        return
|
|
|
|
|
|
|
|
class TestRandomMotionBlur3D(RandomGeneratorBaseTests):
    """Tests for ``random_motion_blur_generator3d``."""

    @pytest.mark.parametrize('batch_size', [0, 1, 8])
    @pytest.mark.parametrize('kernel_size', [3, (3, 5)])
    @pytest.mark.parametrize('angle', [torch.tensor([(10, 30), (30, 60), (60, 90)])])
    @pytest.mark.parametrize('direction', [torch.tensor([-1, -1]), torch.tensor([-1, 1]), torch.tensor([1, 1])])
    @pytest.mark.parametrize('same_on_batch', [True, False])
    def test_valid_param_combinations(self, batch_size, kernel_size, angle, direction, same_on_batch, device, dtype):
        """Smoke test: each parametrized combination must be accepted without raising."""
        angle_arg = angle.to(device=device, dtype=dtype)
        direction_arg = direction.to(device=device, dtype=dtype)
        random_motion_blur_generator3d(
            batch_size=batch_size,
            kernel_size=kernel_size,
            angle=angle_arg,
            direction=direction_arg,
            same_on_batch=same_on_batch,
        )

    @pytest.mark.parametrize(
        'kernel_size,angle,direction',
        [
            (4, torch.tensor([(10, 30), (30, 60), (60, 90)]), torch.tensor([-1, 1])),
            (1, torch.tensor([(10, 30), (30, 60), (60, 90)]), torch.tensor([-1, 1])),
            ((3, 4, 5), torch.tensor([(10, 30), (30, 60), (60, 90)]), torch.tensor([-1, 1])),
            (3, torch.tensor([(10, 30), (30, 60), (60, 90)]), torch.tensor([-2, 1])),
            (3, torch.tensor([(10, 30), (30, 60), (60, 90)]), torch.tensor([-1, 2])),
        ],
    )
    def test_invalid_param_combinations(self, kernel_size, angle, direction, device, dtype):
        """Each parametrized combination is expected to make the generator raise."""
        with pytest.raises(Exception):
            random_motion_blur_generator3d(
                batch_size=8,
                kernel_size=kernel_size,
                angle=angle.to(device=device, dtype=dtype),
                direction=direction.to(device=device, dtype=dtype),
            )

    def test_random_gen(self, device, dtype):
        """With seed 42 the generator must reproduce the recorded blur factors."""
        torch.manual_seed(42)
        angle = torch.tensor([(10, 30), (30, 60), (60, 90)]).to(device=device, dtype=dtype)
        direction = torch.tensor([-1, 1]).to(device=device, dtype=dtype)
        res = random_motion_blur_generator3d(
            batch_size=2,
            kernel_size=3,
            angle=angle,
            direction=direction,
            same_on_batch=False,
        )

        expected = {
            # Kernel sizes are compared as int32, unlike the other factors.
            'ksize_factor': torch.tensor([3, 3], device=device, dtype=torch.int32),
            'angle_factor': torch.tensor(
                [[27.6454, 41.4859, 71.7134], [28.3001, 58.7792, 78.0269]], device=device, dtype=dtype
            ),
            'direction_factor': torch.tensor([-0.4869, 0.5873], device=device, dtype=dtype),
        }

        assert res.keys() == expected.keys()
        for factor, target in expected.items():
            assert_close(res[factor], target, rtol=1e-4, atol=1e-4)

    def test_same_on_batch(self, device, dtype):
        """With ``same_on_batch=True`` both batch items must share the recorded factors."""
        torch.manual_seed(42)
        angle = torch.tensor([(10, 30), (30, 60), (60, 90)]).to(device=device, dtype=dtype)
        direction = torch.tensor([-1, 1]).to(device=device, dtype=dtype)
        res = random_motion_blur_generator3d(
            batch_size=2,
            kernel_size=3,
            angle=angle,
            direction=direction,
            same_on_batch=True,
        )

        expected = {
            # Kernel sizes are compared as int32, unlike the other factors.
            'ksize_factor': torch.tensor([3, 3], device=device, dtype=torch.int32),
            'angle_factor': torch.tensor(
                [[27.6454, 57.4501, 71.4859], [27.6454, 57.4501, 71.4859]], device=device, dtype=dtype
            ),
            'direction_factor': torch.tensor([0.9186, 0.9186], device=device, dtype=dtype),
        }

        assert res.keys() == expected.keys()
        for factor, target in expected.items():
            assert_close(res[factor], target, rtol=1e-4, atol=1e-4)
|
|
|