# Modified from: # https://github.com/anibali/pytorch-stacked-hourglass # https://github.com/bearpaw/pytorch-pose import torch from torch.nn.functional import interpolate def _resize(tensor, size, mode='bilinear'): """Resize the image. Args: tensor (torch.Tensor): The image tensor to be resized. size (tuple of int): Size of the resized image (height, width). mode (str): The pixel sampling interpolation mode to be used. Returns: Tensor: The resized image tensor. """ assert len(size) == 2 # If the tensor is already the desired size, return it immediately. if tensor.shape[-2] == size[0] and tensor.shape[-1] == size[1]: return tensor if not tensor.is_floating_point(): dtype = tensor.dtype tensor = tensor.to(torch.float32) tensor = _resize(tensor, size, mode) return tensor.to(dtype) out_shape = (*tensor.shape[:-2], *size) if tensor.ndimension() < 3: raise Exception('tensor must be at least 2D') elif tensor.ndimension() == 3: tensor = tensor.unsqueeze(0) elif tensor.ndimension() > 4: tensor = tensor.view(-1, *tensor.shape[-3:]) align_corners = None if mode in {'linear', 'bilinear', 'trilinear'}: align_corners = False resized = interpolate(tensor, size=size, mode=mode, align_corners=align_corners) return resized.view(*out_shape) def _crop(tensor, t, l, h, w, padding_mode='constant', fill=0): """Crop the image, padding out-of-bounds regions. Args: tensor (torch.Tensor): The image tensor to be cropped. t (int): Top pixel coordinate. l (int): Left pixel coordinate. h (int): Height of the cropped image. w (int): Width of the cropped image. padding_mode (str): Padding mode (currently "constant" is the only valid option). fill (float): Fill value to use with constant padding. Returns: Tensor: The cropped image tensor. """ # If the _crop region is wholly within the image, simply narrow the tensor. if t >= 0 and l >= 0 and t + h <= tensor.size(-2) and l + w <= tensor.size(-1): return tensor[..., t:t+h, l:l+w] if padding_mode == 'constant': result = torch.full((*tensor.size()[:-2], h, w), fill, device=tensor.device, dtype=tensor.dtype) else: raise Exception('_crop only supports "constant" padding currently.') sx1 = l sy1 = t sx2 = l + w sy2 = t + h dx1 = 0 dy1 = 0 if sx1 < 0: dx1 = -sx1 w += sx1 sx1 = 0 if sy1 < 0: dy1 = -sy1 h += sy1 sy1 = 0 if sx2 >= tensor.size(-1): w -= sx2 - tensor.size(-1) if sy2 >= tensor.size(-2): h -= sy2 - tensor.size(-2) # Copy the in-bounds sub-area of the _crop region into the result tensor. if h > 0 and w > 0: src = tensor.narrow(-2, sy1, h).narrow(-1, sx1, w) dst = result.narrow(-2, dy1, h).narrow(-1, dx1, w) dst.copy_(src) return result def calculate_fit_contain_output_area(in_height, in_width, out_height, out_width): ih, iw = in_height, in_width k = min(out_width / iw, out_height / ih) oh = round(k * ih) ow = round(k * iw) y_off = (out_height - oh) // 2 x_off = (out_width - ow) // 2 return y_off, x_off, oh, ow def fit(tensor, size, fit_mode='cover', resize_mode='bilinear', *, fill=0): """Fit the image within the given spatial dimensions. Args: tensor (torch.Tensor): The image tensor to be fit. size (tuple of int): Size of the output (height, width). fit_mode (str): 'fill', 'contain', or 'cover'. These behave in the same way as CSS's `object-fit` property. fill (float): padding value (only applicable in 'contain' mode). Returns: Tensor: The resized image tensor. """ if fit_mode == 'fill': return _resize(tensor, size, mode=resize_mode) elif fit_mode == 'contain': y_off, x_off, oh, ow = calculate_fit_contain_output_area(*tensor.shape[-2:], *size) resized = _resize(tensor, (oh, ow), mode=resize_mode) result = tensor.new_full((*tensor.size()[:-2], *size), fill) result[..., y_off:y_off + oh, x_off:x_off + ow] = resized return result elif fit_mode == 'cover': ih, iw = tensor.shape[-2:] k = max(size[-1] / iw, size[-2] / ih) oh = round(k * ih) ow = round(k * iw) resized = _resize(tensor, (oh, ow), mode=resize_mode) y_trim = (oh - size[-2]) // 2 x_trim = (ow - size[-1]) // 2 result = _crop(resized, y_trim, x_trim, size[-2], size[-1]) return result raise ValueError('Invalid fit_mode: ' + repr(fit_mode))