File size: 6,273 Bytes
002bd9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
# Copied from: https://github.com/JialianW/GRiT/blob/39b33dbc0900e4be0458af14597fcb1a82d933bb/grit/data/transforms/custom_transform{_impl}.py
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Part of the code is from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/data/transforms.py
# Modified by Xingyi Zhou
# The original code is under Apache-2.0 License
import numpy as np
import torch
import torch.nn.functional as F
from fvcore.transforms.transform import (
CropTransform,
HFlipTransform,
NoOpTransform,
Transform,
TransformList,
)
from PIL import Image
try:
import cv2 # noqa
except ImportError:
# OpenCV is an optional dependency at the moment
pass
# Public API of this module.
__all__ = ["EfficientDetResizeCropTransform"]
class EfficientDetResizeCropTransform(Transform):
    """Resize an image to ``(scaled_h, scaled_w)``, then crop a window of
    ``target_size`` whose top-left corner is ``(offset_y, offset_x)``.

    The crop window is clipped to the resized image bounds. Coordinates are
    mapped by multiplying by ``img_scale`` and then subtracting the offset.
    """

    def __init__(self, scaled_h, scaled_w, offset_y, offset_x, img_scale, target_size, interp=None):
        """
        Args:
            scaled_h, scaled_w (int): image size after resizing, before cropping.
            offset_y, offset_x (int): top-left corner of the crop window in the
                resized image.
            img_scale (float): factor that ``apply_coords`` multiplies
                coordinates by (before subtracting the offset).
            target_size (tuple[int, int]): ``(height, width)`` of the crop
                window; it is clipped to ``(scaled_h, scaled_w)``.
            interp: PIL interpolation method; defaults to ``Image.BILINEAR``.
        """
        # TODO decide on PIL vs opencv
        super().__init__()
        if interp is None:
            interp = Image.BILINEAR
        # Exposes the constructor arguments as attributes (self.scaled_h, ...),
        # which the methods below read. Must run after `interp` is defaulted.
        self._set_attributes(locals())

    def apply_image(self, img, interp=None):
        """Resize ``img`` to ``(scaled_h, scaled_w)`` and crop it to the target window.

        Args:
            img (ndarray): image whose first two dims are (height, width), with
                at most 4 dims in total. uint8 arrays are resized with PIL;
                other dtypes with ``torch.nn.functional.interpolate``.
            interp: PIL interpolation method that overrides ``self.interp`` for
                this call. For non-uint8 input it must be one of
                ``Image.NEAREST``, ``Image.BILINEAR`` or ``Image.BICUBIC``.

        Returns:
            ndarray: the resized and cropped image.
        """
        assert len(img.shape) <= 4
        interp_method = interp if interp is not None else self.interp
        if img.dtype == np.uint8:
            # PIL cannot build an image from an HxWx1 array: resize the single
            # channel as a 2-D image and restore the channel axis afterwards.
            single_channel = img.ndim == 3 and img.shape[2] == 1
            pil_image = Image.fromarray(img[:, :, 0] if single_channel else img)
            pil_image = pil_image.resize((self.scaled_w, self.scaled_h), interp_method)
            ret = np.asarray(pil_image)
            if single_channel:
                ret = np.expand_dims(ret, -1)
        else:
            # PIL only supports uint8. torch.from_numpy rejects negative strides
            # and .view() needs a compatible layout, so make the array contiguous.
            img = torch.from_numpy(np.ascontiguousarray(img))
            shape = list(img.shape)
            shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:]
            img = img.view(shape_4d).permute(2, 3, 0, 1)  # hw(c) -> nchw
            _PIL_RESIZE_TO_INTERPOLATE_MODE = {
                Image.NEAREST: "nearest",
                Image.BILINEAR: "bilinear",
                Image.BICUBIC: "bicubic",
            }
            # Honor the per-call override (e.g. NEAREST from apply_segmentation).
            mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[interp_method]
            # F.interpolate only accepts align_corners for interpolating modes.
            align_corners = None if mode == "nearest" else False
            img = F.interpolate(
                img, (self.scaled_h, self.scaled_w), mode=mode, align_corners=align_corners
            )
            shape[:2] = (self.scaled_h, self.scaled_w)
            ret = img.permute(2, 3, 0, 1).view(shape).numpy()  # nchw -> hw(c)
        return self._crop(ret)

    def _crop(self, img):
        """Crop a resized array to the target window, clipped to the resized size."""
        right = min(self.scaled_w, self.offset_x + self.target_size[1])
        lower = min(self.scaled_h, self.offset_y + self.target_size[0])
        if len(img.shape) <= 3:
            return img[self.offset_y : lower, self.offset_x : right]
        # NOTE(review): for 4-D non-uint8 input the resize in apply_image treats
        # dims 0-1 as spatial, but this crop slices dims 1-2. Kept unchanged for
        # backward compatibility; confirm the intended 4-D layout.
        return img[..., self.offset_y : lower, self.offset_x : right, :]

    def apply_coords(self, coords):
        """Map Nx2 (x, y) coordinates into the resized-and-cropped frame.

        Scales by ``img_scale`` and then subtracts the crop offset. ``coords``
        is modified in place and also returned.
        """
        coords[:, 0] = coords[:, 0] * self.img_scale
        coords[:, 1] = coords[:, 1] * self.img_scale
        coords[:, 0] -= self.offset_x
        coords[:, 1] -= self.offset_y
        return coords

    def apply_segmentation(self, segmentation):
        """Resize and crop a segmentation map using nearest-neighbor interpolation."""
        segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
        return segmentation

    def inverse(self):
        """Not supported: this transform does not provide an inverse Transform."""
        raise NotImplementedError

    def inverse_apply_coords(self, coords):
        """Map Nx2 (x, y) coordinates from the cropped frame back to the input frame.

        This undoes ``apply_coords``: it adds the crop offset, then divides by
        ``img_scale``. ``coords`` is modified in place and also returned.
        """
        coords[:, 0] += self.offset_x
        coords[:, 1] += self.offset_y
        coords[:, 0] = coords[:, 0] / self.img_scale
        coords[:, 1] = coords[:, 1] / self.img_scale
        return coords

    def inverse_apply_box(self, box: np.ndarray) -> np.ndarray:
        """Map XYXY boxes from the cropped frame back to the input frame.

        Args:
            box (ndarray): boxes reshapeable to Nx4 as (x0, y0, x1, y1).

        Returns:
            ndarray: Nx4 axis-aligned boxes enclosing the inverse-mapped corners.
        """
        # Corner order: (x0, y0), (x1, y0), (x0, y1), (x1, y1).
        idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten()
        coords = np.asarray(box).reshape(-1, 4)[:, idxs].reshape(-1, 2)
        # inverse_apply_coords writes in place; with an integer array the division
        # by img_scale would be silently truncated, so promote to float first.
        if not np.issubdtype(coords.dtype, np.floating):
            coords = coords.astype(np.float64)
        coords = self.inverse_apply_coords(coords).reshape((-1, 4, 2))
        minxy = coords.min(axis=1)
        maxxy = coords.max(axis=1)
        return np.concatenate((minxy, maxxy), axis=1)
# ANOTHER FILE
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Part of the code is from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/data/transforms.py
# Modified by Xingyi Zhou
# The original code is under Apache-2.0 License
import numpy as np
from PIL import Image
from transforms.augmentation import Augmentation
# Public API of this module.
__all__ = ["EfficientDetResizeCrop"]
class EfficientDetResizeCrop(Augmentation):
    """EfficientDet-style random scale jitter followed by a random crop.

    The augmentation works in three steps:

    1. A scale factor is drawn with ``np.random.uniform(*scale)`` and
       multiplied into the target size.
    2. The image is resized, keeping its aspect ratio, by the largest factor
       that fits it inside that scaled target box.
    3. A ``target_size`` window is cropped. The crop starts at a random offset
       along each axis where the resized image is larger than the target, and
       at 0 otherwise.
    """

    def __init__(self, size, scale, interp=Image.BILINEAR):
        """
        Args:
            size (int or tuple[int, int]): output crop size. A scalar gives a
                square ``(size, size)`` target; a 2-element ``(height, width)``
                sequence gives a rectangular one.
            scale: arguments unpacked into ``np.random.uniform`` to draw the
                scale factor, presumably a ``(low, high)`` range.
            interp: PIL interpolation method passed to the transform.

        Raises:
            ValueError: if ``size`` is a sequence whose length is not 2.
        """
        super().__init__()
        if isinstance(size, (tuple, list)):
            if len(size) != 2:
                raise ValueError(f"size must be a scalar or (height, width), got {size!r}")
            self.target_size = (size[0], size[1])
        else:
            self.target_size = (size, size)
        self.scale = scale
        self.interp = interp

    def get_transform(self, img):
        """Sample a random resize-and-crop transform for ``img``.

        Args:
            img (ndarray): image whose first two dims are (height, width).

        Returns:
            EfficientDetResizeCropTransform: the sampled transform.
        """
        # Select a random scale factor and apply it to the target box.
        scale_factor = np.random.uniform(*self.scale)
        scaled_target_height = scale_factor * self.target_size[0]
        scaled_target_width = scale_factor * self.target_size[1]
        # Largest aspect-preserving scale that fits the image inside the scaled
        # target box; the resized size is then truncated to integers.
        width, height = img.shape[1], img.shape[0]
        img_scale_y = scaled_target_height / height
        img_scale_x = scaled_target_width / width
        img_scale = min(img_scale_y, img_scale_x)
        scaled_h = int(height * img_scale)
        scaled_w = int(width * img_scale)
        # Random crop offset along each axis where the scaled image exceeds the
        # target size (0 otherwise). Draw y before x to keep the RNG sequence.
        offset_y = scaled_h - self.target_size[0]
        offset_x = scaled_w - self.target_size[1]
        offset_y = int(max(0.0, float(offset_y)) * np.random.uniform(0, 1))
        offset_x = int(max(0.0, float(offset_x)) * np.random.uniform(0, 1))
        return EfficientDetResizeCropTransform(
            scaled_h, scaled_w, offset_y, offset_x, img_scale, self.target_size, self.interp
        )
|