File size: 2,677 Bytes
237de06 d8aafaa 237de06 d8aafaa 237de06 d8aafaa 237de06 d8aafaa 237de06 d8aafaa 237de06 d8aafaa 237de06 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
from PIL import Image
import numpy as np
import torch
from torchvision.transforms import functional as TF
class Compose:
    """Chain several transforms into a single callable.

    Every transform is invoked as ``transform(image, boxes)`` and must return
    an ``(image, boxes)`` pair, which becomes the input of the next one.

    Transforms exposing a ``set_parent`` method receive this instance at
    construction time so they can call back into the pipeline; ``Mosaic``
    uses that to call ``get_more_data`` and to read ``image_size``.

    NOTE(review): ``image_size`` is not defined here; presumably a subclass
    or the owning dataset sets it -- confirm before using ``Mosaic``.
    """

    def __init__(self, transforms):
        """Store *transforms* and register this pipeline as their parent."""
        self.transforms = transforms
        for transform in self.transforms:
            if hasattr(transform, "set_parent"):
                transform.set_parent(self)

    def __call__(self, image, boxes):
        """Run every transform in order and return the final pair."""
        for transform in self.transforms:
            image, boxes = transform(image, boxes)
        return image, boxes

    def get_more_data(self, num_images=1):
        """Return *num_images* additional ``(image, boxes)`` samples.

        Hook for transforms that need more than one sample. The base class
        has no data source, so this must be overridden. The count parameter
        exists because ``Mosaic`` calls ``get_more_data(3)``; without it the
        base class would fail with ``TypeError`` instead of the intended
        ``NotImplementedError``.

        Raises:
            NotImplementedError: always, unless overridden.
        """
        raise NotImplementedError(
            "This method should be overridden by subclass instances!"
        )
class RandomHorizontalFlip:
    """Randomly flip the image horizontally and mirror the boxes to match.

    Columns 1 and 3 of ``boxes`` are treated as the left and right
    x-coordinates. The flip maps ``x`` to ``1 - x``, which presumes the
    coordinates are expressed as fractions of the image width.
    """

    def __init__(self, prob=0.5):
        """Store the flip probability *prob*."""
        self.prob = prob

    def __call__(self, image, boxes):
        """Return ``(image, boxes)``, flipped with probability ``self.prob``.

        When no flip happens the inputs are returned unchanged (same
        objects). When a flip happens, a mirrored *copy* of ``boxes`` is
        returned, so labels cached by the caller are never modified.
        """
        if torch.rand(1) < self.prob:
            image = TF.hflip(image)
            # Work on a copy: writing into the caller's tensor would corrupt
            # labels that a dataset keeps around between epochs.
            boxes = boxes.clone() if torch.is_tensor(boxes) else boxes.copy()
            # The old right edge becomes the new left edge and vice versa,
            # hence the swapped column order on the right-hand side.
            boxes[:, [1, 3]] = 1 - boxes[:, [3, 1]]
        return image, boxes
class Mosaic:
    """Combine the current sample with three others into one 2x2 mosaic.

    The transform needs a parent (set via ``set_parent``) that provides
    ``image_size`` and ``get_more_data(n)``. ``boxes`` rows are read as
    ``[col0, xmin, ymin, xmax, ymax]``: column 0 is copied through untouched
    (presumably the class id), and the four coordinates are treated as
    fractions of each tile's own width/height.
    """

    def __init__(self, prob=0.5):
        """Store the probability *prob* of applying the mosaic."""
        self.prob = prob
        self.parent = None

    def set_parent(self, parent):
        """Remember the pipeline used to fetch extra samples and image size."""
        self.parent = parent

    def __call__(self, image, boxes):
        """Return a mosaic ``(image, boxes)`` with probability ``self.prob``.

        Otherwise the inputs are returned unchanged.

        Raises:
            AssertionError: if the mosaic is to be applied but no parent has
                been set.
        """
        if torch.rand(1) >= self.prob:
            return image, boxes

        # Raised explicitly instead of via ``assert`` so the check survives
        # ``python -O``; the exception type is kept for existing callers.
        if self.parent is None:
            raise AssertionError(
                "Parent is not set. Mosaic cannot retrieve image size."
            )

        img_sz = self.parent.image_size
        canvas_side = 2 * img_sz
        # Three extra samples; accept any iterable of (image, boxes) pairs.
        data = [(image, boxes), *self.parent.get_more_data(3)]

        mosaic_image = Image.new("RGB", (canvas_side, canvas_side))
        # Each tile is anchored at the canvas centre. A step of -1 shifts it
        # by its own width/height so it ends at the centre; 0 starts it there.
        # Order: top-left, top-right, bottom-left, bottom-right.
        steps = ((-1, -1), (0, -1), (-1, 0), (0, 0))

        all_labels = []
        for (tile, tile_boxes), (step_x, step_y) in zip(data, steps):
            tile_w, tile_h = tile.size
            # Plain Python ints keep NumPy scalars out of the torch math.
            left = int(img_sz + step_x * tile_w)
            top = int(img_sz + step_y * tile_h)
            mosaic_image.paste(tile, (left, top))

            # Tile fraction -> canvas pixels -> canvas fraction.
            # NOTE(review): tiles larger than ``img_sz`` overhang the canvas
            # and their boxes are not clipped -- confirm inputs are resized.
            x_min = (tile_boxes[:, 1] * tile_w + left) / canvas_side
            y_min = (tile_boxes[:, 2] * tile_h + top) / canvas_side
            x_max = (tile_boxes[:, 3] * tile_w + left) / canvas_side
            y_max = (tile_boxes[:, 4] * tile_h + top) / canvas_side
            all_labels.append(
                torch.stack(
                    [tile_boxes[:, 0], x_min, y_min, x_max, y_max], dim=1
                )
            )

        return mosaic_image, torch.cat(all_labels, dim=0)
|