File size: 4,580 Bytes
62c7319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b973ee
 
 
62c7319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b973ee
62c7319
 
8b973ee
 
62c7319
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from typing import Dict
import numpy as np
import torch
import kornia.augmentation as K
from kornia.geometry.transform import warp_perspective

# Adapted from Kornia
class GeometricSequential:
    """Compose random geometric augmentations into one homography and warp with it.

    Each transform in ``transforms`` is gated independently on every call with a
    single ``np.random.rand() < t.p`` draw, so the decision is shared by the whole
    batch. Sampled matrices are right-multiplied into an accumulator, so the
    homography returned for the transforms that fired is ``T_1 @ T_2 @ ... @ T_k``.
    """

    def __init__(self, *transforms, align_corners=True) -> None:
        self.transforms = transforms
        self.align_corners = align_corners

    def __call__(self, x, mode="bilinear"):
        """Sample a random composite warp and apply it to ``x``.

        Args:
            x: 4-D image batch, unpacked as ``(b, c, h, w)``.
            mode: interpolation mode forwarded to ``warp_perspective``.

        Returns:
            Tuple ``(warped, M)``. ``warped`` keeps ``x``'s spatial size. ``M`` is
            the accumulated homography batch (it starts as a ``(b, 3, 3)``
            identity) and can be passed to :meth:`apply_transform` to warp
            another tensor identically.
        """
        b, c, h, w = x.shape
        # ``repeat`` (not ``expand``) so ``M`` owns its memory even when no
        # transform fires: an expanded identity is a stride-0 view whose batch
        # entries alias one buffer and cannot be modified in place by callers.
        M = torch.eye(3, device=x.device)[None].repeat(b, 1, 1)
        for t in self.transforms:
            if np.random.rand() < t.p:
                params = t.generate_parameters((b, c, h, w))
                # NOTE(review): third argument (flags) passed as None — depends on
                # the installed Kornia's ``compute_transformation`` signature.
                M = M.matmul(t.compute_transformation(x, params, None))
        return self.apply_transform(x, M, mode=mode), M

    def apply_transform(self, x, M, mode="bilinear"):
        """Warp ``x`` with a precomputed homography batch ``M``.

        ``x`` must be 4-D (unpacked as ``(b, c, h, w)``); the output keeps its
        spatial size ``(h, w)``.
        """
        _, _, h, w = x.shape
        return warp_perspective(
            x, M, dsize=(h, w), align_corners=self.align_corners, mode=mode
        )


class RandomPerspective(K.RandomPerspective):
    """Kornia ``RandomPerspective`` whose corner offsets may point inward or outward.

    Each image corner is displaced by an offset drawn uniformly from
    ``[-distortion_scale * width / 2, +distortion_scale * width / 2]`` in x and
    ``[-distortion_scale * height / 2, +distortion_scale * height / 2]`` in y.
    The warped quadrilateral can therefore also extend beyond the image borders.
    """

    def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]:
        """Sample start/end corner points for a batch shaped ``(B, ..., H, W)``."""
        distortion_scale = torch.as_tensor(
            self.distortion_scale, device=self._device, dtype=self._dtype
        )
        return self.random_perspective_generator(
            batch_shape[0],
            batch_shape[-2],
            batch_shape[-1],
            distortion_scale,
            self.same_on_batch,
            self.device,
            self.dtype,
        )

    def random_perspective_generator(
        self,
        batch_size: int,
        height: int,
        width: int,
        distortion_scale: torch.Tensor,
        same_on_batch: bool = False,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ) -> Dict[str, torch.Tensor]:
        r"""Get parameters for ``perspective`` for a random perspective transform.

        Args:
            batch_size (int): the tensor batch size.
            height (int) : height of the image.
            width (int): width of the image.
            distortion_scale (torch.Tensor): scalar in [0, 1] controlling the maximum
                corner displacement (as a fraction of half the image size).
            same_on_batch (bool): if True, draw one set of corner offsets and share
                it across the whole batch. Default: False.
            device (torch.device): accepted for signature compatibility with Kornia's
                generator; unused here. Random numbers are drawn on
                ``distortion_scale``'s device.
            dtype (torch.dtype): accepted for signature compatibility; unused here.
                Random numbers use ``distortion_scale``'s dtype.

        Returns:
            params Dict[str, torch.Tensor]: parameters to be passed for transformation.
                - start_points (torch.Tensor): image corners (TL, TR, BR, BL) with a shape of (B, 4, 2).
                - end_points (torch.Tensor): displaced corners with a shape of (B, 4, 2).

        Raises:
            AssertionError: if ``distortion_scale`` is not a scalar in [0, 1], or if
                ``height``/``width`` are not positive ints.

        Note:
            The generated random numbers are not reproducible across different devices and dtypes.
        """
        if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1):
            raise AssertionError(
                f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}."
            )
        if not (
            type(height) is int and height > 0 and type(width) is int and width > 0
        ):
            raise AssertionError(
                f"'height' and 'width' must be integers. Got {height}, {width}."
            )

        start_points: torch.Tensor = torch.tensor(
            [[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]],
            device=distortion_scale.device,
            dtype=distortion_scale.dtype,
        ).expand(batch_size, -1, -1)

        # Maximum displacement per axis: at most half of the image size.
        fx = distortion_scale * width / 2
        fy = distortion_scale * height / 2
        factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2)

        # Offsets uniform in [-1, 1), so corners may move inward or outward.
        if same_on_batch:
            # One draw shared by every sample, as ``same_on_batch`` promises.
            offset = (torch.rand_like(start_points[:1]) - 0.5) * 2
            offset = offset.expand(batch_size, -1, -1)
        else:
            offset = (torch.rand_like(start_points) - 0.5) * 2
        end_points = start_points + factor * offset

        return dict(start_points=start_points, end_points=end_points)


class RandomErasing:
    """Erase random rectangles from an image batch and mirror them onto its depth.

    Args:
        p: probability forwarded to ``K.RandomErasing``. When ``p`` is not
            positive, ``__call__`` returns its inputs untouched.
        scale: upper bound of the erased-area range; the lower bound is fixed
            at 0.02.
    """

    def __init__(self, p=0.0, scale=0.0) -> None:
        self.p = p
        self.scale = scale
        area_range = (0.02, scale)
        self.random_eraser = K.RandomErasing(scale=area_range, p=p)

    def __call__(self, image, depth):
        """Return ``(image, depth)``, erasing both when ``p`` is positive."""
        if not self.p > 0:
            # Erasing disabled: hand the inputs straight back.
            return image, depth
        eraser = self.random_eraser
        erased_image = eraser(image)
        # Re-feed the eraser's ``_params`` (presumably those sampled by the call
        # above) so the depth map is erased at the same location as the image.
        erased_depth = eraser(depth, params=eraser._params)
        return erased_image, erased_depth