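# Provenance (web file-viewer header, kept as a comment so the module parses):
# uploaded by Sirus1; duplicated from TencentARC/VLog; revision 6f6830f;
# virus scan: none found; file size 4.14 kB.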
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Part of the code is from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/data/transforms.py
# Modified by Xingyi Zhou
# The original code is under Apache-2.0 License
import numpy as np
import torch
import torch.nn.functional as F
from fvcore.transforms.transform import (
CropTransform,
HFlipTransform,
NoOpTransform,
Transform,
TransformList,
)
from PIL import Image
try:
import cv2 # noqa
except ImportError:
# OpenCV is an optional dependency at the moment
pass
# Public API of this module: only the transform class defined below is exported.
__all__ = [
"EfficientDetResizeCropTransform",
]
class EfficientDetResizeCropTransform(Transform):
    """
    Resize an image to ``(scaled_h, scaled_w)`` and then crop a window out of
    the resized result.

    The crop window starts at ``(offset_y, offset_x)`` in the resized image and
    extends ``target_size[0]`` rows by ``target_size[1]`` columns, clipped to
    the resized image bounds. Coordinates are mapped with the matching
    scale-then-shift (see :meth:`apply_coords`).
    """

    def __init__(self, scaled_h, scaled_w, offset_y, offset_x, img_scale,
                 target_size, interp=None):
        """
        Args:
            scaled_h, scaled_w (int): size the image is resized to before
                cropping.
            offset_y, offset_x (int): top-left corner of the crop window,
                measured in the resized image.
            img_scale (float): factor that coordinates are multiplied by in
                :meth:`apply_coords` (and divided by in
                :meth:`inverse_apply_coords`). Presumably the same ratio that
                produced ``scaled_h``/``scaled_w`` -- confirm against the caller.
            target_size (sequence): ``target_size[0]`` is the crop height and
                ``target_size[1]`` the crop width.
            interp: PIL interpolation methods, defaults to bilinear.
        """
        # TODO decide on PIL vs opencv
        super().__init__()
        if interp is None:
            interp = Image.BILINEAR
        # Stores every constructor argument as an attribute of the same name.
        self._set_attributes(locals())

    def _crop(self, img):
        """Cut the (clipped) crop window out of an already resized array."""
        # Clip so the window never extends past the resized image.
        right = min(self.scaled_w, self.offset_x + self.target_size[1])
        lower = min(self.scaled_h, self.offset_y + self.target_size[0])
        if len(img.shape) <= 3:
            return img[self.offset_y: lower, self.offset_x: right]
        # NOTE(review): for 4-D input this slices the 2nd/3rd-to-last axes,
        # while the resize step treats the first two axes as H, W -- confirm
        # the intended 4-D layout. Behavior is kept as in the original code.
        return img[..., self.offset_y: lower, self.offset_x: right, :]

    def apply_image(self, img, interp=None):
        """
        Resize ``img`` to ``(scaled_h, scaled_w)`` and crop the target window.

        Args:
            img (ndarray): image whose first two axes are height and width;
                at most 4 dimensions.
            interp: PIL interpolation method overriding ``self.interp`` for
                this call only.

        Returns:
            ndarray: the resized and cropped image.

        Raises:
            KeyError: for non-uint8 input when the interpolation method has no
                torch equivalent (only nearest, bilinear and bicubic do).
        """
        assert len(img.shape) <= 4
        interp_method = interp if interp is not None else self.interp

        if img.dtype == np.uint8:
            # PIL cannot build an image from an HxWx1 array; drop the channel
            # axis for the resize and put it back afterwards.
            single_channel = len(img.shape) > 2 and img.shape[2] == 1
            if single_channel:
                pil_image = Image.fromarray(img[:, :, 0])
            else:
                pil_image = Image.fromarray(img)
            pil_image = pil_image.resize((self.scaled_w, self.scaled_h), interp_method)
            ret = np.asarray(pil_image)
            if single_channel:
                ret = np.expand_dims(ret, -1)
        else:
            # PIL only supports uint8
            if any(stride < 0 for stride in img.strides):
                # torch.from_numpy rejects negative strides (e.g. flipped views).
                img = np.ascontiguousarray(img)
            img = torch.from_numpy(img)
            shape = list(img.shape)
            shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:]
            img = img.view(shape_4d).permute(2, 3, 0, 1)  # hw(c) -> nchw
            _PIL_RESIZE_TO_INTERPOLATE_MODE = {
                Image.NEAREST: "nearest",
                Image.BILINEAR: "bilinear",
                Image.BICUBIC: "bicubic",
            }
            # Honor the per-call override so that segmentation masks really
            # get nearest-neighbor resizing instead of self.interp.
            mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[interp_method]
            # F.interpolate only accepts align_corners for interpolating modes.
            align_corners = None if mode == "nearest" else False
            img = F.interpolate(
                img, (self.scaled_h, self.scaled_w), mode=mode, align_corners=align_corners
            )
            shape[:2] = (self.scaled_h, self.scaled_w)
            ret = img.permute(2, 3, 0, 1).view(shape).numpy()  # nchw -> hw(c)

        return self._crop(ret)

    def apply_coords(self, coords):
        """
        Map coordinates from the original image into the resized-and-cropped
        image: scale by ``img_scale``, then subtract the crop offset.

        Args:
            coords (ndarray): Nx2 array with x in column 0 and y in column 1.
                It is modified in place.

        Returns:
            ndarray: the same ``coords`` array, transformed.
        """
        coords[:, 0] = coords[:, 0] * self.img_scale
        coords[:, 1] = coords[:, 1] * self.img_scale
        coords[:, 0] -= self.offset_x
        coords[:, 1] -= self.offset_y
        return coords

    def apply_segmentation(self, segmentation):
        """
        Apply the resize-and-crop to a segmentation map using nearest-neighbor
        interpolation, so that no new label values are produced by blending.
        """
        segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
        return segmentation

    def inverse(self):
        """Not available as a Transform object; use the ``inverse_apply_*`` methods."""
        raise NotImplementedError

    def inverse_apply_coords(self, coords):
        """
        Undo :meth:`apply_coords`: add the crop offset back, then divide by
        ``img_scale``.

        Args:
            coords (ndarray): Nx2 array with x in column 0 and y in column 1.
                It is modified in place.

        Returns:
            ndarray: the same ``coords`` array, mapped back to the original image.
        """
        coords[:, 0] += self.offset_x
        coords[:, 1] += self.offset_y
        coords[:, 0] = coords[:, 0] / self.img_scale
        coords[:, 1] = coords[:, 1] / self.img_scale
        return coords

    def inverse_apply_box(self, box: np.ndarray) -> np.ndarray:
        """
        Map boxes from the resized-and-cropped image back to the original image.

        Args:
            box (ndarray): boxes reshapeable to Nx4, read as (x0, y0, x1, y1)
                columns. The input is not modified.

        Returns:
            ndarray: Nx4 array of axis-aligned boxes (min x, min y, max x, max y)
            enclosing the four back-projected corners of each input box.
        """
        # Column indices that expand each box into its four (x, y) corners.
        idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten()
        # Fancy indexing copies, so the in-place inverse below is safe.
        coords = np.asarray(box).reshape(-1, 4)[:, idxs].reshape(-1, 2)
        coords = self.inverse_apply_coords(coords).reshape((-1, 4, 2))
        minxy = coords.min(axis=1)
        maxxy = coords.max(axis=1)
        trans_boxes = np.concatenate((minxy, maxxy), axis=1)
        return trans_boxes