File size: 4,817 Bytes
753fd9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# Modified from:
#   https://github.com/anibali/pytorch-stacked-hourglass 
#   https://github.com/bearpaw/pytorch-pose

import torch
from torch.nn.functional import interpolate


def _resize(tensor, size, mode='bilinear'):
    """Resize the image.

    Args:
        tensor (torch.Tensor): The image tensor to be resized.
        size (tuple of int): Size of the resized image (height, width).
        mode (str): The pixel sampling interpolation mode to be used.

    Returns:
        Tensor: The resized image tensor.
    """
    assert len(size) == 2

    # If the tensor is already the desired size, return it immediately.
    if tensor.shape[-2] == size[0] and tensor.shape[-1] == size[1]:
        return tensor

    if not tensor.is_floating_point():
        dtype = tensor.dtype
        tensor = tensor.to(torch.float32)
        tensor = _resize(tensor, size, mode)
        return tensor.to(dtype)

    out_shape = (*tensor.shape[:-2], *size)
    if tensor.ndimension() < 3:
        raise Exception('tensor must be at least 2D')
    elif tensor.ndimension() == 3:
        tensor = tensor.unsqueeze(0)
    elif tensor.ndimension() > 4:
        tensor = tensor.view(-1, *tensor.shape[-3:])
    align_corners = None
    if mode in {'linear', 'bilinear', 'trilinear'}:
        align_corners = False
    resized = interpolate(tensor, size=size, mode=mode, align_corners=align_corners)
    return resized.view(*out_shape)


def _crop(tensor, t, l, h, w, padding_mode='constant', fill=0):
    """Crop the image, padding out-of-bounds regions.

    Args:
        tensor (torch.Tensor): The image tensor to be cropped.
        t (int): Top pixel coordinate.
        l (int): Left pixel coordinate.
        h (int): Height of the cropped image.
        w (int): Width of the cropped image.
        padding_mode (str): Padding mode (currently "constant" is the only valid option).
        fill (float): Fill value to use with constant padding.

    Returns:
        Tensor: The cropped image tensor.
    """
    # If the _crop region is wholly within the image, simply narrow the tensor.
    if t >= 0 and l >= 0 and t + h <= tensor.size(-2) and l + w <= tensor.size(-1):
        return tensor[..., t:t+h, l:l+w]

    if padding_mode == 'constant':
        result = torch.full((*tensor.size()[:-2], h, w), fill,
                            device=tensor.device, dtype=tensor.dtype)
    else:
        raise Exception('_crop only supports "constant" padding currently.')

    sx1 = l
    sy1 = t
    sx2 = l + w
    sy2 = t + h
    dx1 = 0
    dy1 = 0

    if sx1 < 0:
        dx1 = -sx1
        w += sx1
        sx1 = 0

    if sy1 < 0:
        dy1 = -sy1
        h += sy1
        sy1 = 0

    if sx2 >= tensor.size(-1):
        w -= sx2 - tensor.size(-1)

    if sy2 >= tensor.size(-2):
        h -= sy2 - tensor.size(-2)

    # Copy the in-bounds sub-area of the _crop region into the result tensor.
    if h > 0 and w > 0:
        src = tensor.narrow(-2, sy1, h).narrow(-1, sx1, w)
        dst = result.narrow(-2, dy1, h).narrow(-1, dx1, w)
        dst.copy_(src)

    return result


def calculate_fit_contain_output_area(in_height, in_width, out_height, out_width):
    """Compute the centred letterbox area for 'contain'-style fitting.

    The input is scaled by the largest factor that keeps it entirely inside
    the output, preserving aspect ratio, and centred within the output.

    Args:
        in_height (int): Height of the input image.
        in_width (int): Width of the input image.
        out_height (int): Height of the output canvas.
        out_width (int): Width of the output canvas.

    Returns:
        tuple of int: (y_offset, x_offset, scaled_height, scaled_width).
    """
    scale = min(out_height / in_height, out_width / in_width)
    scaled_h = round(scale * in_height)
    scaled_w = round(scale * in_width)
    y_offset = (out_height - scaled_h) // 2
    x_offset = (out_width - scaled_w) // 2
    return y_offset, x_offset, scaled_h, scaled_w


def fit(tensor, size, fit_mode='cover', resize_mode='bilinear', *, fill=0):
    """Fit the image within the given spatial dimensions.

    Args:
        tensor (torch.Tensor): The image tensor to be fit.
        size (tuple of int): Size of the output (height, width).
        fit_mode (str): 'fill', 'contain', or 'cover'. These behave in the same way as CSS's
                        `object-fit` property.
        resize_mode (str): Interpolation mode, forwarded to the resize step.
        fill (float): padding value (only applicable in 'contain' mode).

    Returns:
        Tensor: The resized image tensor.

    Raises:
        ValueError: If `fit_mode` is not one of 'fill', 'contain', 'cover'.
    """
    if fit_mode == 'fill':
        # Stretch to the target size, ignoring aspect ratio.
        return _resize(tensor, size, mode=resize_mode)

    if fit_mode == 'contain':
        # Scale to fit inside the output, letterboxing the rest with `fill`.
        y_off, x_off, scaled_h, scaled_w = calculate_fit_contain_output_area(
            *tensor.shape[-2:], *size)
        scaled = _resize(tensor, (scaled_h, scaled_w), mode=resize_mode)
        out = tensor.new_full((*tensor.size()[:-2], *size), fill)
        out[..., y_off:y_off + scaled_h, x_off:x_off + scaled_w] = scaled
        return out

    if fit_mode == 'cover':
        # Scale to cover the whole output, then centre-crop the overflow.
        in_h, in_w = tensor.shape[-2:]
        scale = max(size[-1] / in_w, size[-2] / in_h)
        scaled_h = round(scale * in_h)
        scaled_w = round(scale * in_w)
        scaled = _resize(tensor, (scaled_h, scaled_w), mode=resize_mode)
        trim_y = (scaled_h - size[-2]) // 2
        trim_x = (scaled_w - size[-1]) // 2
        return _crop(scaled, trim_y, trim_x, size[-2], size[-1])

    raise ValueError('Invalid fit_mode: ' + repr(fit_mode))