| | from typing import Tuple |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | from kornia.filters import filter2d, get_gaussian_kernel2d |
| | from kornia.utils import create_meshgrid |
| |
|
| | __all__ = ["elastic_transform2d"] |
| |
|
| |
|
def elastic_transform2d(
    image: torch.Tensor,
    noise: torch.Tensor,
    kernel_size: Tuple[int, int] = (63, 63),
    sigma: Tuple[float, float] = (32.0, 32.0),
    alpha: Tuple[float, float] = (1.0, 1.0),
    align_corners: bool = False,
    mode: str = 'bilinear',
) -> torch.Tensor:
    r"""Apply elastic transform of images as described in :cite:`Simard2003BestPF`.

    .. image:: _static/img/elastic_transform2d.png

    The two noise channels are smoothed with a Gaussian kernel, scaled by ``alpha``,
    added to a sampling grid built for the image size and used to resample the image
    with ``grid_sample``.

    Args:
        image: Input image to be transformed with shape :math:`(B, C, H, W)`.
        noise: Noise image used to spatially transform the input image. Must have the
          same batch size and resolution as the input image, with shape
          :math:`(B, 2, H, W)`. The coordinates order is expected to be x-y, i.e.
          channel 0 is the x displacement and channel 1 the y displacement.
        kernel_size: the size of the Gaussian kernel.
        sigma: The standard deviation of the Gaussian in the y and x directions,
          respectively: ``sigma[0]`` smooths the y displacement and ``sigma[1]`` the
          x displacement. Larger sigma results in smaller pixel displacements.
        alpha: The scaling factor that controls the intensity of the deformation.
          ``alpha[0]`` multiplies the smoothed x displacement and ``alpha[1]`` the
          smoothed y displacement.
        align_corners: Interpolation flag used by ``grid_sample``.
        mode: Interpolation mode used by ``grid_sample``. Either ``'bilinear'`` or ``'nearest'``.

    .. note::
        ``sigma`` and ``alpha`` can also be a ``torch.Tensor``. However, you could not torchscript
        this function with tensor until PyTorch 1.8 is released.

    Returns:
        the elastically transformed input image with shape :math:`(B,C,H,W)`.

    Raises:
        TypeError: if ``image`` or ``noise`` is not a ``torch.Tensor``.
        ValueError: if ``image`` is not 4-dimensional, if ``noise`` is not of shape
          :math:`(B, 2, H, W)`, or if the batch size or resolution of ``noise`` differs
          from that of ``image``.

    Example:
        >>> image = torch.rand(1, 3, 5, 5)
        >>> noise = torch.rand(1, 2, 5, 5, requires_grad=True)
        >>> image_hat = elastic_transform2d(image, noise, (3, 3))
        >>> image_hat.mean().backward()

        >>> image = torch.rand(1, 3, 5, 5)
        >>> noise = torch.rand(1, 2, 5, 5)
        >>> sigma = torch.tensor([4., 4.], requires_grad=True)
        >>> image_hat = elastic_transform2d(image, noise, (3, 3), sigma)
        >>> image_hat.mean().backward()

        >>> image = torch.rand(1, 3, 5, 5)
        >>> noise = torch.rand(1, 2, 5, 5)
        >>> alpha = torch.tensor([16., 32.], requires_grad=True)
        >>> image_hat = elastic_transform2d(image, noise, (3, 3), alpha=alpha)
        >>> image_hat.mean().backward()
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input image is not torch.Tensor. Got {type(image)}")

    if not isinstance(noise, torch.Tensor):
        raise TypeError(f"Input noise is not torch.Tensor. Got {type(noise)}")

    if len(image.shape) != 4:
        raise ValueError(f"Invalid image shape, we expect BxCxHxW. Got: {image.shape}")

    if len(noise.shape) != 4 or noise.shape[1] != 2:
        raise ValueError(f"Invalid noise shape, we expect Bx2xHxW. Got: {noise.shape}")

    # Fail early with a clear message: otherwise a mismatch only shows up further down
    # as a broadcasting / grid_sample error, or silently broadcasts a size-1 dimension.
    if noise.shape[0] != image.shape[0] or noise.shape[-2:] != image.shape[-2:]:
        raise ValueError(
            "Invalid noise shape, it must match the image batch size and resolution. "
            f"Got image: {image.shape} and noise: {noise.shape}"
        )

    # One isotropic Gaussian kernel per displacement channel; [None] adds a leading
    # dimension before the kernel is handed to filter2d.
    # sigma[0] drives the y displacement and sigma[1] the x displacement.
    kernel_for_y: torch.Tensor = get_gaussian_kernel2d(kernel_size, (sigma[0], sigma[0]))[None]
    kernel_for_x: torch.Tensor = get_gaussian_kernel2d(kernel_size, (sigma[1], sigma[1]))[None]

    # Smooth each noise channel and scale it with its 'alpha' factor.
    # NOTE: alpha is indexed in x-y order (alpha[0] -> x) whereas sigma is indexed in
    # y-x order (sigma[1] -> x); this mirrors the long-standing behaviour.
    disp_x: torch.Tensor = filter2d(noise[:, :1], kernel=kernel_for_x, border_type='constant') * alpha[0]
    disp_y: torch.Tensor = filter2d(noise[:, 1:], kernel=kernel_for_y, border_type='constant') * alpha[1]

    # Stack into a single (B, H, W, 2) displacement field with (x, y) as last dimension.
    disp = torch.cat([disp_x, disp_y], dim=1).permute(0, 2, 3, 1)

    # Offset the sampling grid by the displacement and resample the image. The result
    # is clamped to [-1, 1] so that every sampling location stays inside that range.
    _, _, h, w = image.shape
    grid = create_meshgrid(h, w, device=image.device).to(image.dtype)
    warped = F.grid_sample(image, (grid + disp).clamp(-1, 1), align_corners=align_corners, mode=mode)

    return warped
| |
|