from typing import Dict

import numpy as np
import torch
import kornia.augmentation as K
from kornia.geometry.transform import warp_perspective


# Adapted from Kornia
class GeometricSequential:
    """Compose geometric augmentations into a single perspective warp.

    Each transform is applied independently with its own probability ``t.p``.
    The 3x3 matrices of the transforms that fire are multiplied together, and
    the input is warped once with the combined matrix.
    """

    def __init__(self, *transforms, align_corners=True) -> None:
        # Each transform must expose ``p``, ``generate_parameters`` and
        # ``compute_transformation`` (as used in ``__call__``).
        self.transforms = transforms
        self.align_corners = align_corners

    def __call__(self, x, mode="bilinear"):
        """Randomly sample the transforms and warp ``x`` with their product.

        Args:
            x: batched image tensor; unpacked as ``(b, c, h, w)``.
            mode: interpolation mode forwarded to ``warp_perspective``.

        Returns:
            Tuple ``(warped, M)``, where ``M`` is the ``(b, 3, 3)`` combined
            transformation (the identity when no transform fired).
        """
        b, c, h, w = x.shape
        # Accumulate in at least float32, and in float64 for double inputs, so
        # the matrix matches what the warp of ``x`` needs instead of being
        # hard-wired to float32.
        acc_dtype = torch.promote_types(x.dtype, torch.float32)
        M = torch.eye(3, device=x.device, dtype=acc_dtype)[None].expand(b, 3, 3)
        for t in self.transforms:
            # One coin flip per transform for the whole batch.
            if np.random.rand() < t.p:
                T = t.compute_transformation(x, t.generate_parameters((b, c, h, w)))
                # Align the sampled matrix with the accumulator before composing.
                M = M.matmul(T.to(device=M.device, dtype=M.dtype))
        return self.apply_transform(x, M, mode=mode), M

    def apply_transform(self, x, M, mode="bilinear"):
        """Warp ``x`` with a given ``M``, keeping the output at ``x``'s height/width."""
        _, _, h, w = x.shape
        return warp_perspective(
            x, M, dsize=(h, w), align_corners=self.align_corners, mode=mode
        )


class RandomPerspective(K.RandomPerspective):
    """``K.RandomPerspective`` with a custom corner-offset generator.

    Each image corner is displaced by an offset drawn uniformly from
    ``[-distortion_scale * width / 2, distortion_scale * width / 2]`` in x.
    The y offset uses the analogous range with ``height``.
    """

    def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]:
        """Sample start/end corner points for a batch of shape ``batch_shape``."""
        distortion_scale = torch.as_tensor(
            self.distortion_scale, device=self._device, dtype=self._dtype
        )
        return self.random_perspective_generator(
            batch_shape[0],
            batch_shape[-2],
            batch_shape[-1],
            distortion_scale,
            self.same_on_batch,
            self.device,
            self.dtype,
        )

    def random_perspective_generator(
        self,
        batch_size: int,
        height: int,
        width: int,
        distortion_scale: torch.Tensor,
        same_on_batch: bool = False,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ) -> Dict[str, torch.Tensor]:
        r"""Get parameters for ``perspective`` for a random perspective transform.

        Args:
            batch_size (int): the tensor batch size.
            height (int): height of the image.
            width (int): width of the image.
            distortion_scale (torch.Tensor): it controls the degree of distortion
                and ranges from 0 to 1.
            same_on_batch (bool): apply the same transformation across the batch.
                Default: False.
            device (torch.device): accepted for signature compatibility. The
                random values are generated on ``distortion_scale``'s device.
            dtype (torch.dtype): accepted for signature compatibility. The
                random values use ``distortion_scale``'s dtype.

        Returns:
            params Dict[str, torch.Tensor]: parameters to be passed for transformation.
                - start_points (torch.Tensor): element-wise perspective source areas
                  with a shape of (B, 4, 2).
                - end_points (torch.Tensor): element-wise perspective target areas
                  with a shape of (B, 4, 2).

        Raises:
            AssertionError: if ``distortion_scale`` is not a scalar in [0, 1], or
                if ``height``/``width`` are not positive ints.

        Note:
            The generated random numbers are not reproducible across different
            devices and dtypes.
        """
        if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1):
            raise AssertionError(
                f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}."
            )
        if not (
            type(height) is int and height > 0 and type(width) is int and width > 0
        ):
            raise AssertionError(
                f"'height' and 'width' must be integers. Got {height}, {width}."
            )

        # Image corners in (x, y) order: top-left, top-right, bottom-right,
        # bottom-left, shared across the batch.
        start_points: torch.Tensor = torch.tensor(
            [[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]],
            device=distortion_scale.device,
            dtype=distortion_scale.dtype,
        ).expand(batch_size, -1, -1)

        # Generate a random offset not larger than half of the image along each axis.
        fx = distortion_scale * width / 2
        fy = distortion_scale * height / 2
        factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2)

        if same_on_batch:
            # One set of corner offsets, reused for every sample in the batch.
            offset = (torch.rand_like(start_points[:1]) - 0.5) * 2
            offset = offset.expand_as(start_points)
        else:
            # Independent offsets in [-1, 1) for every sample and corner.
            offset = (torch.rand_like(start_points) - 0.5) * 2

        end_points = start_points + factor * offset
        return dict(start_points=start_points, end_points=end_points)