Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Part of the code is from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/data/transforms.py
# Modified by Xingyi Zhou
# The original code is licensed under the Apache-2.0 License
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from fvcore.transforms.transform import ( | |
CropTransform, | |
HFlipTransform, | |
NoOpTransform, | |
Transform, | |
TransformList, | |
) | |
from PIL import Image | |
try: | |
import cv2 # noqa | |
except ImportError: | |
# OpenCV is an optional dependency at the moment | |
pass | |
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ["EfficientDetResizeCropTransform"]
class EfficientDetResizeCropTransform(Transform):
    """
    Resize an image to (``scaled_h``, ``scaled_w``), then crop a window of at
    most ``target_size`` whose top-left corner is (``offset_y``, ``offset_x``).

    All geometry is fixed at construction time. Coordinates are mapped the
    same way as pixels: multiplied by ``img_scale`` and then shifted by
    ``(-offset_x, -offset_y)``.
    """

    # PIL resampling filter -> ``F.interpolate`` mode. Used only for the
    # non-uint8 path, because PIL is only used to resize uint8 arrays here.
    _PIL_RESIZE_TO_INTERPOLATE_MODE = {
        Image.NEAREST: "nearest",
        Image.BILINEAR: "bilinear",
        Image.BICUBIC: "bicubic",
    }

    def __init__(self, scaled_h, scaled_w, offset_y, offset_x, img_scale,
                 target_size, interp=None):
        """
        Args:
            scaled_h, scaled_w (int): image size after resizing, before
                cropping.
            offset_y, offset_x (int): top-left corner of the crop window, in
                resized-image pixels.
            img_scale: factor mapping original-image coordinates into the
                resized image (used by ``apply_coords``).
            target_size: (height, width) of the crop window. The crop is
                clipped to the resized image bounds.
            interp: PIL interpolation method, defaults to bilinear.
        """
        # TODO decide on PIL vs opencv
        super().__init__()
        if interp is None:
            interp = Image.BILINEAR
        # The constructor arguments become same-named attributes
        # (self.scaled_h, self.interp, ...). locals() is passed wholesale,
        # so do not introduce other locals above this call.
        self._set_attributes(locals())

    def apply_image(self, img, interp=None):
        """
        Resize ``img`` to (scaled_h, scaled_w), then crop the target window.

        Args:
            img (ndarray): image with at most 4 dims; the first two axes are
                treated as (H, W). uint8 images are resized with PIL; other
                dtypes with ``torch.nn.functional.interpolate``.
            interp: PIL interpolation method overriding ``self.interp``. It
                is honored on both paths.

        Returns:
            ndarray: the resized and cropped image.
        """
        assert len(img.shape) <= 4
        interp_method = interp if interp is not None else self.interp

        if img.dtype == np.uint8:
            resized = self._resize_with_pil(img, interp_method)
        else:
            # PIL only supports uint8
            resized = self._resize_with_torch(img, interp_method)
        return self._crop(resized)

    def _resize_with_pil(self, img, interp_method):
        """Resize a uint8 array to (scaled_h, scaled_w) using PIL."""
        # PIL cannot build an image from an HxWx1 array, so resize the
        # single channel as a 2-D image and restore the trailing axis.
        single_channel = len(img.shape) == 3 and img.shape[2] == 1
        pil_image = Image.fromarray(img[:, :, 0] if single_channel else img)
        pil_image = pil_image.resize((self.scaled_w, self.scaled_h), interp_method)
        ret = np.asarray(pil_image)
        if single_channel:
            ret = np.expand_dims(ret, -1)
        return ret

    def _resize_with_torch(self, img, interp_method):
        """Resize a non-uint8 array to (scaled_h, scaled_w) using F.interpolate."""
        mode = self._PIL_RESIZE_TO_INTERPOLATE_MODE[interp_method]
        # F.interpolate only accepts align_corners for the interpolating
        # modes (linear/bilinear/bicubic/trilinear), not for "nearest".
        align_corners = None if mode == "nearest" else False

        tensor = torch.from_numpy(img)
        shape = list(tensor.shape)
        shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:]
        tensor = tensor.view(shape_4d).permute(2, 3, 0, 1)  # hw(c) -> nchw
        tensor = F.interpolate(
            tensor, (self.scaled_h, self.scaled_w), mode=mode, align_corners=align_corners
        )
        shape[:2] = (self.scaled_h, self.scaled_w)
        return tensor.permute(2, 3, 0, 1).view(shape).numpy()  # nchw -> hw(c)

    def _crop(self, img):
        """Cut the target window out of a resized image, clipped to its bounds."""
        right = min(self.scaled_w, self.offset_x + self.target_size[1])
        lower = min(self.scaled_h, self.offset_y + self.target_size[0])
        if len(img.shape) <= 3:
            return img[self.offset_y: lower, self.offset_x: right]
        # NOTE(review): for 4-D arrays this slices the three axes before the
        # last one, whereas the torch resize path treats axes 0/1 as (H, W).
        # Confirm the intended 4-D layout.
        return img[..., self.offset_y: lower, self.offset_x: right, :]

    def apply_coords(self, coords):
        """
        Map (x, y) points from the original image into the cropped output.

        Args:
            coords (ndarray): array whose column 0 is x and column 1 is y.
                It is modified in place.

        Returns:
            ndarray: ``coords``, scaled by ``img_scale`` and then shifted by
            the crop offset.
        """
        coords[:, 0] = coords[:, 0] * self.img_scale
        coords[:, 1] = coords[:, 1] * self.img_scale
        coords[:, 0] -= self.offset_x
        coords[:, 1] -= self.offset_y
        return coords

    def apply_segmentation(self, segmentation):
        """
        Resize and crop a segmentation map with nearest-neighbor
        interpolation, so label values are never blended.
        """
        segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
        return segmentation

    def inverse(self):
        # Cropping discards pixels, so no image-level inverse is provided.
        # Geometry can be mapped back with inverse_apply_coords and
        # inverse_apply_box.
        raise NotImplementedError

    def inverse_apply_coords(self, coords):
        """
        Map (x, y) points from the cropped output back to the original image.

        This undoes ``apply_coords``. ``coords`` is modified in place and
        returned.
        """
        coords[:, 0] += self.offset_x
        coords[:, 1] += self.offset_y
        coords[:, 0] = coords[:, 0] / self.img_scale
        coords[:, 1] = coords[:, 1] / self.img_scale
        return coords

    def inverse_apply_box(self, box: np.ndarray) -> np.ndarray:
        """
        Map XYXY boxes from the cropped output back to original-image space.

        Each box is expanded into its four corners. The corners are
        inverse-mapped, and the axis-aligned min/max of the result is
        returned as an (N, 4) XYXY array.
        """
        # Corner index pairs: (x0, y0), (x1, y0), (x0, y1), (x1, y1).
        idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten()
        # Fancy indexing copies, so the caller's `box` is not mutated below.
        coords = np.asarray(box).reshape(-1, 4)[:, idxs].reshape(-1, 2)
        coords = self.inverse_apply_coords(coords).reshape((-1, 4, 2))
        minxy = coords.min(axis=1)
        maxxy = coords.max(axis=1)
        trans_boxes = np.concatenate((minxy, maxxy), axis=1)
        return trans_boxes