Dexter's picture
Upload folder using huggingface_hub
36c95ba verified
from typing import Optional
import torch
# TODO: implement width of the line
def draw_rectangle(
image: torch.Tensor,
rectangle: torch.Tensor,
color: Optional[torch.Tensor] = None,
fill: Optional[bool] = None,
) -> torch.Tensor:
r"""Draw N rectangles on a batch of image tensors.
Args:
image: is tensor of BxCxHxW.
rectangle: represents number of rectangles to draw in BxNx4
N is the number of boxes to draw per batch index[x1, y1, x2, y2]
4 is in (top_left.x, top_left.y, bot_right.x, bot_right.y).
color: a size 1, size 3, BxNx1, or BxNx3 tensor.
If C is 3, and color is 1 channel it will be broadcasted.
fill: is a flag used to fill the boxes with color if True.
Returns:
This operation modifies image inplace but also returns the drawn tensor for
convenience with same shape the of the input BxCxHxW.
Example:
>>> img = torch.rand(2, 3, 10, 12)
>>> rect = torch.tensor([[[0, 0, 4, 4]], [[4, 4, 10, 10]]])
>>> out = draw_rectangle(img, rect)
"""
batch, c, h, w = image.shape
batch_rect, num_rectangle, num_points = rectangle.shape
if batch != batch_rect:
raise AssertionError("Image batch and rectangle batch must be equal")
if num_points != 4:
raise AssertionError("Number of points in rectangle must be 4")
# clone rectangle, in case it's been expanded assignment from clipping causes problems
rectangle = rectangle.long().clone()
# clip rectangle to hxw bounds
rectangle[:, :, 1::2] = torch.clamp(rectangle[:, :, 1::2], 0, h - 1)
rectangle[:, :, ::2] = torch.clamp(rectangle[:, :, ::2], 0, w - 1)
if color is None:
color = torch.tensor([0.0] * c).expand(batch, num_rectangle, c)
if fill is None:
fill = False
if len(color.shape) == 1:
color = color.expand(batch, num_rectangle, c)
b, n, color_channels = color.shape
if color_channels == 1 and c == 3:
color = color.expand(batch, num_rectangle, c)
for b in range(batch):
for n in range(num_rectangle):
if fill:
image[
b,
:,
int(rectangle[b, n, 1]): int(rectangle[b, n, 3] + 1),
int(rectangle[b, n, 0]): int(rectangle[b, n, 2] + 1),
] = color[b, n, :, None, None]
else:
image[b, :, int(rectangle[b, n, 1]): int(rectangle[b, n, 3] + 1), rectangle[b, n, 0]] = color[
b, n, :, None
]
image[b, :, int(rectangle[b, n, 1]): int(rectangle[b, n, 3] + 1), rectangle[b, n, 2]] = color[
b, n, :, None
]
image[b, :, rectangle[b, n, 1], int(rectangle[b, n, 0]): int(rectangle[b, n, 2] + 1)] = color[
b, n, :, None
]
image[b, :, rectangle[b, n, 3], int(rectangle[b, n, 0]): int(rectangle[b, n, 2] + 1)] = color[
b, n, :, None
]
return image