import numbers
import warnings
from collections.abc import Sequence

import numpy as np
import torch
import torchvision.transforms as T
from torchvision.transforms.functional import (
    InterpolationMode,
    _interpolation_modes_from_int,
    crop,
    get_image_num_channels,
    get_image_size,
    perspective,
)


class RandomScale:
    """Resize by a random scale factor drawn uniformly from ``scale_range``,
    with each output side clamped below by ``min_size``."""

    def __init__(self, scale_range=(0.8, 1.2), min_size=None):
        self.scale_range = scale_range
        self.min_size = min_size if min_size is not None else 0

    def __call__(self, img):
        if isinstance(img, torch.Tensor):
            height, width = img.shape[-2:]
        else:
            # PIL images report size as (width, height).
            width, height = img.size
        s = np.random.uniform(*self.scale_range)
        resize_h = max(int(height * s), self.min_size)
        resize_w = max(int(width * s), self.min_size)
        return T.Resize((resize_h, resize_w))(img)

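# Illustrative note (added, not from the original file): min_size clamps each
# side independently, so e.g. a 200x300 image at s=1.0 with min_size=256
# becomes 256x300, changing the aspect ratio.

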
class RandomSizeCrop:
    """Crop a random region whose sides are a fraction ``s ~ U(min_cover, 1)``
    of the original sides."""

    def __init__(self, min_cover):
        self.min_cover = min_cover

    def __call__(self, img):
        if self.min_cover == 1:
            return img
        if isinstance(img, torch.Tensor):
            h, w = img.shape[-2:]
        else:
            # PIL images report size as (width, height).
            w, h = img.size
        s = np.random.uniform(self.min_cover, 1)
        return T.RandomCrop((int(h * s), int(w * s)))(img)

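# Illustrative note (added, not from the original file): because the same
# factor s scales both sides, RandomSizeCrop(0.8) keeps at least 80% of each
# side and therefore at least 64% of the image area.

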
class DivisibleCrop:
    """Center-crop so that both sides are divisible by ``d``."""

    def __init__(self, d):
        self.d = d

    def __call__(self, img):
        if isinstance(img, torch.Tensor):
            h, w = img.shape[-2:]
        else:
            w, h = img.size
        # Round each side down to the nearest multiple of d.
        h = h - h % self.d
        w = w - w % self.d
        return T.CenterCrop((h, w))(img)

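# Hypothetical usage sketch (assumed, not from the original file): d is
# typically a network's total downsampling stride, so feature maps divide
# evenly, e.g.
#     x = torch.rand(3, 517, 781)
#     DivisibleCrop(32)(x).shape  # torch.Size([3, 512, 768])

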
class ToTensorSafe:
    """Like ``T.ToTensor`` but a no-op when the input is already a tensor."""

    def __call__(self, img):
        if isinstance(img, torch.Tensor):
            return img
        return T.ToTensor()(img)


class BorderlessRandomPerspective:
    """Applies a random perspective transform, then crops the warped image to
    the axis-aligned rectangle spanned by the innermost warped corners (which
    contains no fill-value borders) and resizes the result to 224x224.

    Args:
        distortion_scale (float): controls the degree of distortion, ranging
            from 0 to 1. Default is 0.5.
        p (float): probability of the image being transformed. Default is 0.5.
        interpolation (InterpolationMode): desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is
            ``InterpolationMode.BILINEAR``. If the input is a tensor, only
            ``InterpolationMode.NEAREST`` and ``InterpolationMode.BILINEAR``
            are supported. For backward compatibility, integer values
            (e.g. ``PIL.Image.NEAREST``) are still accepted.
        fill (sequence or number): pixel fill value for the area outside the
            transformed image. Default is ``0``. If a number is given, it is
            used for all bands.
    """

def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0):
super().__init__()
self.p = p
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationMode instead of int. "
"Please, use InterpolationMode enum."
)
interpolation = _interpolation_modes_from_int(interpolation)
self.interpolation = interpolation
self.distortion_scale = distortion_scale
if fill is None:
fill = 0
elif not isinstance(fill, (Sequence, numbers.Number)):
raise TypeError("Fill should be either a sequence or a number.")
self.fill = fill
    @staticmethod
    def get_crop_endpoints(endpoints):
        # The warped image occupies a convex quadrilateral on the output
        # canvas. Taking the innermost coordinate of each pair of corners
        # yields an axis-aligned rectangle guaranteed to lie inside the
        # quadrilateral, i.e. a crop with no fill-value borders.
        topleft, topright, botright, botleft = endpoints
        topy = max(topleft[1], topright[1])
        leftx = max(topleft[0], botleft[0])
        boty = min(botleft[1], botright[1])
        rightx = min(topright[0], botright[0])
        h = boty - topy
        w = rightx - leftx
        # Returned as (top, left, height, width), matching torchvision's crop().
        return topy, leftx, h, w
    def __call__(self, img):
        fill = self.fill
        if isinstance(img, torch.Tensor):
            # Tensor inputs need an explicit per-channel fill list.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * get_image_num_channels(img)
            else:
                fill = [float(f) for f in fill]
        if torch.rand(1) < self.p:
            width, height = get_image_size(img)
            startpoints, endpoints = T.RandomPerspective.get_params(width, height, self.distortion_scale)
            warped = perspective(img, startpoints, endpoints, self.interpolation, fill)
            # Crop away the borders introduced by the warp, then map to a
            # fixed 224x224 output.
            i, j, h, w = self.get_crop_endpoints(endpoints)
            cropped = crop(warped, i, j, h, w)
            return T.Compose([T.Resize(224), T.CenterCrop(224)])(cropped)
        return img
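

# A minimal end-to-end sketch (added; the pipeline order, parameter values,
# and random input below are illustrative assumptions, not part of the
# original file).
if __name__ == "__main__":
    pipeline = T.Compose(
        [
            ToTensorSafe(),
            RandomScale(scale_range=(0.8, 1.2), min_size=224),
            RandomSizeCrop(min_cover=0.9),
            BorderlessRandomPerspective(distortion_scale=0.5, p=1.0),
            DivisibleCrop(16),
        ]
    )
    img = torch.rand(3, 256, 256)  # stand-in for a real image
    out = pipeline(img)
    print(out.shape)  # torch.Size([3, 224, 224]); 224 is divisible by 16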