import numbers
import warnings
from collections.abc import Sequence

import numpy as np
import torch
import torchvision.transforms as T
from torchvision.transforms.functional import (
    InterpolationMode,
    _interpolation_modes_from_int,
    crop,
    get_image_num_channels,
    get_image_size,
    perspective,
)


class RandomScale:
    """Resize an image by a scale factor drawn uniformly from ``scale_range``."""

    def __init__(self, scale_range=(0.8, 1.2), min_size=None):
        super().__init__()
        self.scale_range = scale_range
        self.min_size = min_size if min_size is not None else 0

    def __call__(self, img):
        if isinstance(img, torch.Tensor):
            height, width = img.shape[-2:]
        else:
            # PIL images report size as (width, height)
            width, height = img.size
        s = np.random.uniform(*self.scale_range)
        # Clamp each side to min_size so later fixed-size crops cannot fail
        resize_h = max(int(height * s), self.min_size)
        resize_w = max(int(width * s), self.min_size)
        return T.Resize((resize_h, resize_w))(img)


class RandomSizeCrop:
    """Randomly crop a region whose sides are scaled by a factor drawn from [min_cover, 1)."""

    def __init__(self, min_cover):
        super().__init__()
        self.min_cover = min_cover

    def __call__(self, img):
        if self.min_cover == 1:
            return img
        if isinstance(img, torch.Tensor):
            h, w = img.shape[-2:]
        else:
            w, h = img.size
        s = np.random.uniform(self.min_cover, 1)
        return T.RandomCrop((int(h * s), int(w * s)))(img)


class DivisibleCrop:
    """Center-crop an image so that both sides are divisible by ``d``."""

    def __init__(self, d):
        super().__init__()
        self.d = d

    def __call__(self, img):
        if isinstance(img, torch.Tensor):
            h, w = img.shape[-2:]
        else:
            w, h = img.size
        # Round each side down to the nearest multiple of d
        h = h - h % self.d
        w = w - w % self.d
        return T.CenterCrop((h, w))(img)


class ToTensorSafe:
    """Convert a PIL image to a tensor; pass tensors through unchanged."""

    def __call__(self, img):
        if isinstance(img, torch.Tensor):
            return img
        return T.ToTensor()(img)


class BorderlessRandomPerspective:
    """Apply a random perspective transform, then crop away the fill borders it introduces.

    Args:
        distortion_scale (float): controls the degree of distortion; ranges from 0 to 1.
            Default is 0.5.
        p (float): probability of the image being transformed. Default is 0.5.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST`` and ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility, integer values (e.g. ``PIL.Image.NEAREST``) are still accepted.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands.
    """

    def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0):
        super().__init__()
        self.p = p

        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument interpolation should be of type InterpolationMode instead of int. "
                "Please, use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation
        self.distortion_scale = distortion_scale

        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")

        self.fill = fill

    @staticmethod
    def get_crop_endpoints(endpoints):
        """Compute an axis-aligned rectangle inside the warped quadrilateral.

        ``endpoints`` are the four displaced corners from
        ``RandomPerspective.get_params``, in (x, y) order: top-left, top-right,
        bottom-right, bottom-left. Taking the innermost x and y of the relevant
        corner pairs yields a crop that contains no fill border, returned as
        ``(top, left, height, width)`` for :func:`~torchvision.transforms.functional.crop`.
        """
        topleft, topright, botright, botleft = endpoints
        topy = max(topleft[1], topright[1])
        leftx = max(topleft[0], botleft[0])
        boty = min(botleft[1], botright[1])
        rightx = min(topright[0], botright[0])
        return topy, leftx, boty - topy, rightx - leftx

    def __call__(self, img):
        fill = self.fill
        if isinstance(img, torch.Tensor):
            # Tensor inputs need the fill value expanded to one entry per channel
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * get_image_num_channels(img)
            else:
                fill = [float(f) for f in fill]

        if torch.rand(1) < self.p:
            width, height = get_image_size(img)
            startpoints, endpoints = T.RandomPerspective.get_params(width, height, self.distortion_scale)
            warped = perspective(img, startpoints, endpoints, self.interpolation, fill)
            # Crop to the inner rectangle so no fill-colored border remains
            i, j, h, w = self.get_crop_endpoints(endpoints)
            cropped = crop(warped, i, j, h, w)
            # Restore a fixed 224x224 output after the variable-sized crop
            return T.Compose([T.Resize(224), T.CenterCrop(224)])(cropped)
        return img
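

# A minimal usage sketch, not part of the original module: compose the
# transforms above into a single augmentation pipeline. The min_size of 224,
# min_cover of 0.7, and divisor of 32 below are illustrative assumptions.
if __name__ == "__main__":
    from PIL import Image

    pipeline = T.Compose([
        RandomScale(scale_range=(0.8, 1.2), min_size=224),
        RandomSizeCrop(min_cover=0.7),
        BorderlessRandomPerspective(distortion_scale=0.5, p=0.5),
        ToTensorSafe(),
        DivisibleCrop(32),
    ])

    img = Image.new("RGB", (640, 480))  # stand-in for a real input image
    out = pipeline(img)
    print(out.shape)  # e.g. torch.Size([3, 224, 224]) when the perspective fires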