# NOTE(review): the three lines originally here ("Spaces:" / "Sleeping" /
# "Sleeping") were page-scrape residue, not source code, and were not valid
# Python; converted to this comment so the module parses.
import math | |
import pathlib | |
import warnings | |
from types import FunctionType | |
from typing import Any, BinaryIO, List, Optional, Tuple, Union | |
import numpy as np | |
import torch | |
from PIL import Image, ImageColor, ImageDraw, ImageFont | |
# Explicit public API of this module; names prefixed with "_" below are
# internal helpers and are intentionally excluded.
__all__ = [
    "make_grid",
    "save_image",
    "draw_bounding_boxes",
    "draw_segmentation_masks",
    "draw_keypoints",
    "flow_to_image",
]
def make_grid(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    value_range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: float = 0.0,
    **kwargs,
) -> torch.Tensor:
    """
    Make a grid of images.

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
            or a list of images all of the same size.
        nrow (int, optional): Number of images displayed in each row of the grid.
            The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
        padding (int, optional): amount of padding. Default: ``2``.
        normalize (bool, optional): If True, shift the image to the range (0, 1),
            by the min and max values specified by ``value_range``. Default: ``False``.
        value_range (tuple, optional): tuple (min, max) where min and max are numbers,
            then these numbers are used to normalize the image. By default, min and max
            are computed from the tensor.
        range (tuple, optional):
            .. warning::
                This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``value_range``
                instead.
        scale_each (bool, optional): If ``True``, scale each image in the batch of
            images separately rather than the (min, max) over all images. Default: ``False``.
        pad_value (float, optional): Value for the padded pixels. Default: ``0``.

    Returns:
        grid (Tensor): the tensor containing grid of images.

    Raises:
        TypeError: If ``tensor`` is neither a tensor nor a list of tensors, or
            if ``value_range`` is specified but is not a tuple.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(make_grid)
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")

    if "range" in kwargs:
        warnings.warn(
            "The parameter 'range' is deprecated since 0.12 and will be removed in 0.14. "
            "Please use 'value_range' instead."
        )
        value_range = kwargs["range"]

    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize:
        tensor = tensor.clone()  # avoid modifying tensor in-place
        # Validate with an explicit raise instead of `assert` so the check
        # survives running Python with -O (asserts are stripped).
        if value_range is not None and not isinstance(value_range, tuple):
            raise TypeError("value_range has to be a tuple (min, max) if specified. min and max are numbers")

        def norm_ip(img, low, high):
            # Clamp then rescale to [0, 1]; the 1e-5 floor guards against
            # division by zero when low == high.
            img.clamp_(min=low, max=high)
            img.sub_(low).div_(max(high - low, 1e-5))

        def norm_range(t, value_range):
            if value_range is not None:
                norm_ip(t, value_range[0], value_range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each:
            for t in tensor:  # loop over mini-batch dimension
                norm_range(t, value_range)
        else:
            norm_range(tensor, value_range)

    assert isinstance(tensor, torch.Tensor)
    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            # Tensor.copy_() is a valid method but seems to be missing from the stubs
            # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
            grid.narrow(1, y * height + padding, height - padding).narrow(  # type: ignore[attr-defined]
                2, x * width + padding, width - padding
            ).copy_(tensor[k])
            k = k + 1
    return grid
def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[str, pathlib.Path, BinaryIO],
    format: Optional[str] = None,
    **kwargs,
) -> None:
    """
    Save a given Tensor into an image file.

    Args:
        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
            saves the tensor as a grid of images by calling ``make_grid``.
        fp (string or file object): A filename or a file object
        format(Optional): If omitted, the format to use is determined from the filename extension.
            If a file object was used instead of a filename, this parameter should always be used.
        **kwargs: Other arguments are documented in ``make_grid``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(save_image)
    grid = make_grid(tensor, **kwargs)
    # Map [0, 1] floats onto [0, 255] uint8; the +0.5 rounds to the nearest
    # integer before the clamp-and-cast.
    array = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    Image.fromarray(array).save(fp, format=format)
def draw_bounding_boxes(
    image: torch.Tensor,
    boxes: torch.Tensor,
    labels: Optional[List[str]] = None,
    colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
    fill: Optional[bool] = False,
    width: int = 1,
    font: Optional[str] = None,
    font_size: int = 10,
) -> torch.Tensor:
    """
    Draws bounding boxes on given image.
    The values of the input image should be uint8 between 0 and 255.
    If fill is True, Resulting Tensor should be saved as PNG image.

    Args:
        image (Tensor): Tensor of shape (C x H x W) and dtype uint8.
        boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
            the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
            `0 <= ymin < ymax < H`.
        labels (List[str]): List containing the labels of bounding boxes.
        colors (color or list of colors, optional): List containing the colors
            of the boxes or single color for all boxes. The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
            By default, random colors are generated for boxes.
        fill (bool): If `True` fills the bounding box with specified color.
        width (int): Width of bounding box.
        font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
            also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`,
            `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
        font_size (int): The requested font size in points.

    Returns:
        img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(draw_bounding_boxes)
    # Guard clauses instead of an if/elif chain; each one raises, so the
    # control flow is identical.
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Tensor expected, got {type(image)}")
    if image.dtype != torch.uint8:
        raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
    if image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    if image.size(0) not in {1, 3}:
        raise ValueError("Only grayscale and RGB images are supported")

    num_boxes = boxes.shape[0]

    if labels is None:
        labels: Union[List[str], List[None]] = [None] * num_boxes  # type: ignore[no-redef]
    elif len(labels) != num_boxes:
        raise ValueError(
            f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box."
        )

    if colors is None:
        colors = _generate_color_palette(num_boxes)
    elif isinstance(colors, list):
        if len(colors) < num_boxes:
            raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). ")
    else:
        # A single color was supplied; repeat it for every box.
        colors = [colors] * num_boxes
    # PIL color names / hex strings become RGB tuples; tuples pass through.
    colors = [(ImageColor.getrgb(c) if isinstance(c, str) else c) for c in colors]

    # Grayscale input: replicate the channel so PIL receives an RGB image.
    if image.size(0) == 1:
        image = torch.tile(image, (3, 1, 1))

    canvas = Image.fromarray(image.permute(1, 2, 0).cpu().numpy())
    # An RGBA drawing context is only needed when boxes are filled with a
    # semi-transparent interior.
    draw = ImageDraw.Draw(canvas, "RGBA") if fill else ImageDraw.Draw(canvas)
    txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)

    for bbox, color, label in zip(boxes.to(torch.int64).tolist(), colors, labels):  # type: ignore[arg-type]
        if fill:
            # Alpha 100/255 gives a translucent interior in the box color.
            draw.rectangle(bbox, width=width, outline=color, fill=color + (100,))
        else:
            draw.rectangle(bbox, width=width, outline=color)

        if label is not None:
            margin = width + 1
            draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)

    return torch.from_numpy(np.array(canvas)).permute(2, 0, 1).to(dtype=torch.uint8)
def draw_segmentation_masks(
    image: torch.Tensor,
    masks: torch.Tensor,
    alpha: float = 0.8,
    colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
) -> torch.Tensor:
    """
    Draws segmentation masks on given RGB image.
    The values of the input image should be uint8 between 0 and 255.

    Args:
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
        masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
        alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
            0 means full transparency, 1 means no transparency.
        colors (color or list of colors, optional): List containing the colors
            of the masks or single color for all masks. The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
            By default, random colors are generated for each mask.

    Returns:
        img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(draw_segmentation_masks)
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"The image must be a tensor, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other Image formats are not supported")
    if masks.ndim == 2:
        masks = masks[None, :, :]
    if masks.ndim != 3:
        raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)")
    if masks.dtype != torch.bool:
        raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}")
    if masks.shape[-2:] != image.shape[-2:]:
        raise ValueError("The image and the masks must have the same height and width")

    num_masks = masks.size()[0]
    if num_masks == 0:
        # Bug fix: with zero masks the generated palette is empty and the
        # colors[0] check below raised IndexError. Warn and return unchanged.
        warnings.warn("masks doesn't contain any mask. No mask was drawn")
        return image

    if colors is not None and num_masks > len(colors):
        raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")

    if colors is None:
        colors = _generate_color_palette(num_masks)

    if not isinstance(colors, list):
        colors = [colors]
    if not isinstance(colors[0], (tuple, str)):
        raise ValueError("colors must be a tuple or a string, or a list thereof")
    if isinstance(colors[0], tuple) and len(colors[0]) != 3:
        raise ValueError("It seems that you passed a tuple of colors instead of a list of colors")

    out_dtype = torch.uint8

    colors_ = []
    for color in colors:
        if isinstance(color, str):
            color = ImageColor.getrgb(color)
        colors_.append(torch.tensor(color, dtype=out_dtype))

    img_to_draw = image.detach().clone()
    # TODO: There might be a way to vectorize this
    # Paint every masked pixel with its mask's color; unmasked pixels keep
    # the original value, so the alpha blend below only affects masked areas.
    for mask, color in zip(masks, colors_):
        img_to_draw[:, mask] = color[:, None]

    out = image * (1 - alpha) + img_to_draw * alpha
    return out.to(out_dtype)
def draw_keypoints(
    image: torch.Tensor,
    keypoints: torch.Tensor,
    connectivity: Optional[List[Tuple[int, int]]] = None,
    colors: Optional[Union[str, Tuple[int, int, int]]] = None,
    radius: int = 2,
    width: int = 3,
) -> torch.Tensor:
    """
    Draws Keypoints on given RGB image.
    The values of the input image should be uint8 between 0 and 255.

    Args:
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
        keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances,
            in the format [x, y].
        connectivity (List[Tuple[int, int]]]): A List of tuple where,
            each tuple contains pair of keypoints to be connected.
        colors (str, Tuple): The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
        radius (int): Integer denoting radius of keypoint.
        width (int): Integer denoting width of line connecting keypoints.

    Returns:
        img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(draw_keypoints)
    # Guard clauses; each raises, so control flow matches an if/elif chain.
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"The image must be a tensor, got {type(image)}")
    if image.dtype != torch.uint8:
        raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
    if image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    if image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other Image formats are not supported")

    if keypoints.ndim != 3:
        raise ValueError("keypoints must be of shape (num_instances, K, 2)")

    canvas = Image.fromarray(image.permute(1, 2, 0).cpu().numpy())
    draw = ImageDraw.Draw(canvas)
    instances = keypoints.to(torch.int64).tolist()

    for instance in instances:
        # One filled circle of the requested radius per keypoint.
        for point in instance:
            bounds = [point[0] - radius, point[1] - radius, point[0] + radius, point[1] + radius]
            draw.ellipse(bounds, fill=colors, outline=None, width=0)

        if connectivity:
            for connection in connectivity:
                start, end = instance[connection[0]], instance[connection[1]]
                draw.line(
                    ((start[0], start[1]), (end[0], end[1])),
                    width=width,
                )

    return torch.from_numpy(np.array(canvas)).permute(2, 0, 1).to(dtype=torch.uint8)
# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
    """
    Converts a flow to an RGB image.

    Args:
        flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.

    Returns:
        img (Tensor): Image Tensor of dtype uint8 where each color corresponds
            to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.

    Raises:
        ValueError: If ``flow`` is not float32 or not of shape (2, H, W) / (N, 2, H, W).
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        # Consistency fix: every other public function in this module logs
        # API usage; flow_to_image was the only one that did not.
        _log_api_usage_once(flow_to_image)
    if flow.dtype != torch.float:
        raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")

    orig_shape = flow.shape
    if flow.ndim == 3:
        flow = flow[None]  # Add batch dim

    if flow.ndim != 4 or flow.shape[1] != 2:
        raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")

    # Normalize by the largest flow magnitude across the whole batch; epsilon
    # avoids division by zero for an all-zero flow field.
    max_norm = torch.sum(flow ** 2, dim=1).sqrt().max()
    epsilon = torch.finfo((flow).dtype).eps
    normalized_flow = flow / (max_norm + epsilon)
    img = _normalized_flow_to_image(normalized_flow)

    if len(orig_shape) == 3:
        img = img[0]  # Remove batch dim
    return img
def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
    """
    Converts a batch of normalized flow to an RGB image.

    Args:
        normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)

    Returns:
        img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
    """
    batch, _, height, width = normalized_flow.shape
    device = normalized_flow.device
    flow_image = torch.zeros((batch, 3, height, width), dtype=torch.uint8, device=device)
    wheel = _make_colorwheel().to(device)  # shape [55x3]
    wheel_size = wheel.shape[0]

    # Per-pixel magnitude and angle of the flow vectors.
    magnitude = torch.sum(normalized_flow ** 2, dim=1).sqrt()
    angle = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi

    # Map the angle in [-1, 1] onto a fractional index into the color wheel,
    # then interpolate between the two neighboring wheel entries.
    frac_idx = (angle + 1) / 2 * (wheel_size - 1)
    lo = torch.floor(frac_idx).to(torch.long)
    hi = lo + 1
    hi[hi == wheel_size] = 0  # wrap around the wheel
    frac = frac_idx - lo

    for channel in range(wheel.shape[1]):
        entries = wheel[:, channel]
        col_lo = entries[lo] / 255.0
        col_hi = entries[hi] / 255.0
        col = (1 - frac) * col_lo + frac * col_hi
        # Scale saturation by magnitude: zero flow -> white, max flow -> pure hue.
        col = 1 - magnitude * (1 - col)
        flow_image[:, channel, :, :] = torch.floor(255 * col)
    return flow_image
def _make_colorwheel() -> torch.Tensor:
    """
    Generates a color wheel for optical flow visualization as presented in:
    Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
    URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.

    Returns:
        colorwheel (Tensor[55, 3]): Colorwheel Tensor.
    """
    # Each segment holds one RGB channel at 255 while another channel ramps
    # linearly up or down: (length, held channel, ramping channel, ascending?).
    segments = [
        (15, 0, 1, True),   # RY: red -> yellow
        (6, 1, 0, False),   # YG: yellow -> green
        (4, 1, 2, True),    # GC: green -> cyan
        (11, 2, 1, False),  # CB: cyan -> blue
        (13, 2, 0, True),   # BM: blue -> magenta
        (6, 0, 2, False),   # MR: magenta -> red
    ]
    total_rows = sum(length for length, _, _, _ in segments)
    colorwheel = torch.zeros((total_rows, 3))

    row = 0
    for length, held, ramping, ascending in segments:
        ramp = torch.floor(255 * torch.arange(0, length) / length)
        colorwheel[row : row + length, held] = 255
        colorwheel[row : row + length, ramping] = ramp if ascending else 255 - ramp
        row += length
    return colorwheel
def _generate_color_palette(num_objects: int):
    """Return ``num_objects`` deterministic, visually distinct RGB colors.

    Multiplies a fixed seed vector by the object index and wraps modulo 255,
    so the same index always maps to the same color.
    """
    seed = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
    palette = []
    for idx in range(num_objects):
        palette.append(tuple((idx * seed) % 255))
    return palette
def _log_api_usage_once(obj: Any) -> None: | |
""" | |
Logs API usage(module and name) within an organization. | |
In a large ecosystem, it's often useful to track the PyTorch and | |
TorchVision APIs usage. This API provides the similar functionality to the | |
logging module in the Python stdlib. It can be used for debugging purpose | |
to log which methods are used and by default it is inactive, unless the user | |
manually subscribes a logger via the `SetAPIUsageLogger method <https://github.com/pytorch/pytorch/blob/eb3b9fe719b21fae13c7a7cf3253f970290a573e/c10/util/Logging.cpp#L114>`_. | |
Please note it is triggered only once for the same API call within a process. | |
It does not collect any data from open-source users since it is no-op by default. | |
For more information, please refer to | |
* PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging; | |
* Logging policy: https://github.com/pytorch/vision/issues/5052; | |
Args: | |
obj (class instance or method): an object to extract info from. | |
""" | |
if not obj.__module__.startswith("torchvision"): | |
return | |
name = obj.__class__.__name__ | |
if isinstance(obj, FunctionType): | |
name = obj.__name__ | |
torch._C._log_api_usage_once(f"{obj.__module__}.{name}") | |