Spaces:
Build error
Build error
File size: 4,511 Bytes
1ed7deb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import random
import warnings
from typing import Union
import torch
from torch import Tensor
from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor
from torchvision.transforms.functional import _get_image_size as get_image_size
from taming.data.helper_types import BoundingBox, Image
# Shared converter instance, reused by every call to convert_pil_to_tensor.
pil_to_tensor = PILToTensor()


def convert_pil_to_tensor(image: Image) -> Tensor:
    """
    Convert an image to a tensor using the module-level ``pil_to_tensor`` transform.

    Args:
        image: Image to be converted.
    Returns:
        Tensor: Result of applying ``pil_to_tensor`` to ``image``.
    """
    with warnings.catch_warnings():
        # Every warning raised inside this context is silenced; the target is the
        # PyTorch UserWarning described here: https://github.com/pytorch/vision/issues/2194
        warnings.simplefilter("ignore")
        tensor = pil_to_tensor(image)
    return tensor
class RandomCrop1dReturnCoordinates(RandomCrop):
    """RandomCrop variant that also reports where the crop was taken."""

    def forward(self, img: Image) -> (BoundingBox, Image):
        """
        Additionally to cropping, returns the relative coordinates of the crop bounding box.

        The bounding box is expressed relative to the image the crop is actually taken
        from, i.e. after any padding (explicit ``padding`` as well as ``pad_if_needed``)
        has been applied.

        Args:
            img (PIL Image or Tensor): Image to be cropped.
        Returns:
            Bounding box: x0, y0, w, h
            PIL Image or Tensor: Cropped image.
        Based on:
            torchvision.transforms.RandomCrop, torchvision 1.7.0
        """
        if self.padding is not None:
            img = F.pad(img, self.padding, self.fill, self.padding_mode)

        width, height = get_image_size(img)
        # pad the width if needed
        if self.pad_if_needed and width < self.size[1]:
            padding = [self.size[1] - width, 0]
            img = F.pad(img, padding, self.fill, self.padding_mode)
        # pad the height if needed
        if self.pad_if_needed and height < self.size[0]:
            padding = [0, self.size[0] - height]
            img = F.pad(img, padding, self.fill, self.padding_mode)

        # Re-measure: the crop offsets below refer to the (possibly) padded image, so
        # the normalisation must use that image's dimensions. Using the pre-padding
        # size would mix coordinate frames and could yield relative extents above 1.
        width, height = get_image_size(img)

        i, j, h, w = self.get_params(img, self.size)
        bbox = (j / width, i / height, w / width, h / height)  # x0, y0, w, h
        return bbox, F.crop(img, i, j, h, w)
class Random2dCropReturnCoordinates(torch.nn.Module):
    """
    Crops a randomly sized, randomly placed square and returns the relative
    coordinates of the crop bounding box alongside the cropped image.

    The side length is drawn uniformly from ``[min_size, min(width, height)]``; if the
    shorter image edge does not exceed ``min_size``, the shorter edge itself is used.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
    Returns:
        Bounding box: x0, y0, w, h
        PIL Image or Tensor: Cropped image.
    Based on:
        torchvision.transforms.RandomCrop, torchvision 1.7.0
    """

    def __init__(self, min_size: int):
        super().__init__()
        self.min_size = min_size

    def forward(self, img: Image) -> (BoundingBox, Image):
        width, height = get_image_size(img)
        shorter_edge = min(width, height)

        # No random draw for the side length when there is no range to pick from.
        if shorter_edge > self.min_size:
            side = random.randint(self.min_size, shorter_edge)
        else:
            side = shorter_edge

        # Draw order (side, then vertical, then horizontal offset) is kept so that
        # seeded runs stay reproducible.
        y_offset = random.randint(0, height - side)
        x_offset = random.randint(0, width - side)

        relative_box = (x_offset / width, y_offset / height, side / width, side / height)
        return relative_box, F.crop(img, y_offset, x_offset, side, side)
class CenterCropReturnCoordinates(CenterCrop):
    """CenterCrop variant that also reports the relative bounding box of the crop."""

    @staticmethod
    def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox:
        """
        Relative bounding box (x0, y0, w, h) of the largest centred square of an image.

        Only the image dimensions are used; the configured crop size plays no role.
        NOTE(review): this presumably relies on the image having been resized so that
        its shorter edge equals the crop size -- confirm against the caller.
        """
        if width > height:
            # Landscape: full height, horizontally centred strip.
            rel_width = height / width
            return 0.5 - rel_width / 2, 0., rel_width, 1.0
        # Portrait or square: full width, vertically centred strip.
        rel_height = width / height
        return 0., 0.5 - rel_height / 2, 1.0, rel_height

    def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]):
        """
        Additionally to cropping, returns the relative coordinates of the crop bounding box.

        Args:
            img (PIL Image or Tensor): Image to be cropped.
        Returns:
            Bounding box: x0, y0, w, h
            PIL Image or Tensor: Cropped image.
        Based on:
            torchvision.transforms.CenterCrop (version 1.7.0)
        """
        width, height = get_image_size(img)
        bbox = self.get_bbox_of_center_crop(width, height)
        return bbox, F.center_crop(img, self.size)
class RandomHorizontalFlipReturn(RandomHorizontalFlip):
    """RandomHorizontalFlip variant that also reports whether the flip happened."""

    def forward(self, img: Image) -> (bool, Image):
        """
        Additionally to flipping, returns a boolean whether it was flipped or not.

        Args:
            img (PIL Image or Tensor): Image to be flipped.
        Returns:
            flipped: whether the image was flipped or not
            PIL Image or Tensor: Randomly flipped image.
        Based on:
            torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
        """
        # A single draw from torch's RNG decides the outcome; the flip happens when
        # the draw falls below self.p.
        was_flipped = bool(torch.rand(1) < self.p)
        result = F.hflip(img) if was_flipped else img
        return was_flipped, result
|