|
from typing import Optional |
|
|
|
import torch |
|
|
|
|
|
|
|
|
|
def draw_rectangle( |
|
image: torch.Tensor, |
|
rectangle: torch.Tensor, |
|
color: Optional[torch.Tensor] = None, |
|
fill: Optional[bool] = None, |
|
) -> torch.Tensor: |
|
r"""Draw N rectangles on a batch of image tensors. |
|
|
|
Args: |
|
image: is tensor of BxCxHxW. |
|
rectangle: represents number of rectangles to draw in BxNx4 |
|
N is the number of boxes to draw per batch index[x1, y1, x2, y2] |
|
4 is in (top_left.x, top_left.y, bot_right.x, bot_right.y). |
|
color: a size 1, size 3, BxNx1, or BxNx3 tensor. |
|
If C is 3, and color is 1 channel it will be broadcasted. |
|
fill: is a flag used to fill the boxes with color if True. |
|
|
|
Returns: |
|
This operation modifies image inplace but also returns the drawn tensor for |
|
convenience with same shape the of the input BxCxHxW. |
|
|
|
Example: |
|
>>> img = torch.rand(2, 3, 10, 12) |
|
>>> rect = torch.tensor([[[0, 0, 4, 4]], [[4, 4, 10, 10]]]) |
|
>>> out = draw_rectangle(img, rect) |
|
""" |
|
|
|
batch, c, h, w = image.shape |
|
batch_rect, num_rectangle, num_points = rectangle.shape |
|
if batch != batch_rect: |
|
raise AssertionError("Image batch and rectangle batch must be equal") |
|
if num_points != 4: |
|
raise AssertionError("Number of points in rectangle must be 4") |
|
|
|
|
|
rectangle = rectangle.long().clone() |
|
|
|
|
|
rectangle[:, :, 1::2] = torch.clamp(rectangle[:, :, 1::2], 0, h - 1) |
|
rectangle[:, :, ::2] = torch.clamp(rectangle[:, :, ::2], 0, w - 1) |
|
|
|
if color is None: |
|
color = torch.tensor([0.0] * c).expand(batch, num_rectangle, c) |
|
|
|
if fill is None: |
|
fill = False |
|
|
|
if len(color.shape) == 1: |
|
color = color.expand(batch, num_rectangle, c) |
|
|
|
b, n, color_channels = color.shape |
|
|
|
if color_channels == 1 and c == 3: |
|
color = color.expand(batch, num_rectangle, c) |
|
|
|
for b in range(batch): |
|
for n in range(num_rectangle): |
|
if fill: |
|
image[ |
|
b, |
|
:, |
|
int(rectangle[b, n, 1]): int(rectangle[b, n, 3] + 1), |
|
int(rectangle[b, n, 0]): int(rectangle[b, n, 2] + 1), |
|
] = color[b, n, :, None, None] |
|
else: |
|
image[b, :, int(rectangle[b, n, 1]): int(rectangle[b, n, 3] + 1), rectangle[b, n, 0]] = color[ |
|
b, n, :, None |
|
] |
|
image[b, :, int(rectangle[b, n, 1]): int(rectangle[b, n, 3] + 1), rectangle[b, n, 2]] = color[ |
|
b, n, :, None |
|
] |
|
image[b, :, rectangle[b, n, 1], int(rectangle[b, n, 0]): int(rectangle[b, n, 2] + 1)] = color[ |
|
b, n, :, None |
|
] |
|
image[b, :, rectangle[b, n, 3], int(rectangle[b, n, 0]): int(rectangle[b, n, 2] + 1)] = color[ |
|
b, n, :, None |
|
] |
|
|
|
return image |
|
|