# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. # Note: This file has been barrowed from facebookresearch/slowfast repo. And it is used to add the bounding boxes and predictions to the frame. # TODO: Migrate this into the core PyTorchVideo libarary. from __future__ import annotations import itertools # import logging from types import SimpleNamespace from typing import Dict, List, Optional, Tuple, Union import matplotlib.pyplot as plt import numpy as np import torch from detectron2.utils.visualizer import Visualizer # logger = logging.getLogger(__name__) def _create_text_labels( classes: List[int], scores: List[float], class_names: List[str], ground_truth: bool = False, ) -> List[str]: """ Create text labels. Args: classes (list[int]): a list of class ids for each example. scores (list[float] or None): list of scores for each example. class_names (list[str]): a list of class names, ordered by their ids. ground_truth (bool): whether the labels are ground truth. Returns: labels (list[str]): formatted text labels. """ try: labels = [class_names.get(c, "n/a") for c in classes] except IndexError: # logger.error("Class indices get out of range: {}".format(classes)) return None if ground_truth: labels = ["[{}] {}".format("GT", label) for label in labels] elif scores is not None: assert len(classes) == len(scores) labels = ["[{:.2f}] {}".format(s, label) for s, label in zip(scores, labels)] return labels class ImgVisualizer(Visualizer): def __init__( self, img_rgb: torch.Tensor, meta: Optional[SimpleNamespace] = None, **kwargs ) -> None: """ See https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/visualizer.py for more details. Args: img_rgb: a tensor or numpy array of shape (H, W, C), where H and W correspond to the height and width of the image respectively. C is the number of color channels. The image is required to be in RGB format since that is a requirement of the Matplotlib library. The image is also expected to be in the range [0, 255]. meta (MetadataCatalog): image metadata. See https://github.com/facebookresearch/detectron2/blob/81d5a87763bfc71a492b5be89b74179bd7492f6b/detectron2/data/catalog.py#L90 """ super(ImgVisualizer, self).__init__(img_rgb, meta, **kwargs) def draw_text( self, text: str, position: List[int], *, font_size: Optional[int] = None, color: str = "w", horizontal_alignment: str = "center", vertical_alignment: str = "bottom", box_facecolor: str = "black", alpha: float = 0.5, ) -> None: """ Draw text at the specified position. Args: text (str): the text to draw on image. position (list of 2 ints): the x,y coordinate to place the text. font_size (Optional[int]): font of the text. If not provided, a font size proportional to the image width is calculated and used. color (str): color of the text. Refer to `matplotlib.colors` for full list of formats that are accepted. horizontal_alignment (str): see `matplotlib.text.Text`. vertical_alignment (str): see `matplotlib.text.Text`. box_facecolor (str): color of the box wrapped around the text. Refer to `matplotlib.colors` for full list of formats that are accepted. alpha (float): transparency level of the box. """ if not font_size: font_size = self._default_font_size x, y = position self.output.ax.text( x, y, text, size=font_size * self.output.scale, family="monospace", bbox={ "facecolor": box_facecolor, "alpha": alpha, "pad": 0.7, "edgecolor": "none", }, verticalalignment=vertical_alignment, horizontalalignment=horizontal_alignment, color=color, zorder=10, ) def draw_multiple_text( self, text_ls: List[str], box_coordinate: torch.Tensor, *, top_corner: bool = True, font_size: Optional[int] = None, color: str = "w", box_facecolors: str = "black", alpha: float = 0.5, ) -> None: """ Draw a list of text labels for some bounding box on the image. Args: text_ls (list of strings): a list of text labels. box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom) coordinates of the box. top_corner (bool): If True, draw the text labels at (x_left, y_top) of the box. Else, draw labels at (x_left, y_bottom). font_size (Optional[int]): font of the text. If not provided, a font size proportional to the image width is calculated and used. color (str): color of the text. Refer to `matplotlib.colors` for full list of formats that are accepted. box_facecolors (str): colors of the box wrapped around the text. Refer to `matplotlib.colors` for full list of formats that are accepted. alpha (float): transparency level of the box. """ if not isinstance(box_facecolors, list): box_facecolors = [box_facecolors] * len(text_ls) assert len(box_facecolors) == len( text_ls ), "Number of colors provided is not equal to the number of text labels." if not font_size: font_size = self._default_font_size text_box_width = font_size + font_size // 2 # If the texts does not fit in the assigned location, # we split the text and draw it in another place. if top_corner: num_text_split = self._align_y_top( box_coordinate, len(text_ls), text_box_width ) y_corner = 1 else: num_text_split = len(text_ls) - self._align_y_bottom( box_coordinate, len(text_ls), text_box_width ) y_corner = 3 text_color_sorted = sorted( zip(text_ls, box_facecolors), key=lambda x: x[0], reverse=True ) if len(text_color_sorted) != 0: text_ls, box_facecolors = zip(*text_color_sorted) else: text_ls, box_facecolors = [], [] text_ls, box_facecolors = list(text_ls), list(box_facecolors) self.draw_multiple_text_upward( text_ls[:num_text_split][::-1], box_coordinate, y_corner=y_corner, font_size=font_size, color=color, box_facecolors=box_facecolors[:num_text_split][::-1], alpha=alpha, ) self.draw_multiple_text_downward( text_ls[num_text_split:], box_coordinate, y_corner=y_corner, font_size=font_size, color=color, box_facecolors=box_facecolors[num_text_split:], alpha=alpha, ) def draw_multiple_text_upward( self, text_ls: List[str], box_coordinate: torch.Tensor, *, y_corner: int = 1, font_size: Optional[int] = None, color: str = "w", box_facecolors: str = "black", alpha: float = 0.5, ) -> None: """ Draw a list of text labels for some bounding box on the image in upward direction. The next text label will be on top of the previous one. Args: text_ls (list of strings): a list of text labels. box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom) coordinates of the box. y_corner (int): Value of either 1 or 3. Indicate the index of the y-coordinate of the box to draw labels around. font_size (Optional[int]): font of the text. If not provided, a font size proportional to the image width is calculated and used. color (str): color of the text. Refer to `matplotlib.colors` for full list of formats that are accepted. box_facecolors (str or list of strs): colors of the box wrapped around the text. Refer to `matplotlib.colors` for full list of formats that are accepted. alpha (float): transparency level of the box. """ if not isinstance(box_facecolors, list): box_facecolors = [box_facecolors] * len(text_ls) assert len(box_facecolors) == len( text_ls ), "Number of colors provided is not equal to the number of text labels." assert y_corner in [1, 3], "Y_corner must be either 1 or 3" if not font_size: font_size = self._default_font_size x, horizontal_alignment = self._align_x_coordinate(box_coordinate) y = box_coordinate[y_corner].item() for i, text in enumerate(text_ls): self.draw_text( text, (x, y), font_size=font_size, color=color, horizontal_alignment=horizontal_alignment, vertical_alignment="bottom", box_facecolor=box_facecolors[i], alpha=alpha, ) y -= font_size + font_size // 2 def draw_multiple_text_downward( self, text_ls: List[str], box_coordinate: torch.Tensor, *, y_corner: int = 1, font_size: Optional[int] = None, color: str = "w", box_facecolors: str = "black", alpha: float = 0.5, ) -> None: """ Draw a list of text labels for some bounding box on the image in downward direction. The next text label will be below the previous one. Args: text_ls (list of strings): a list of text labels. box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom) coordinates of the box. y_corner (int): Value of either 1 or 3. Indicate the index of the y-coordinate of the box to draw labels around. font_size (Optional[int]): font of the text. If not provided, a font size proportional to the image width is calculated and used. color (str): color of the text. Refer to `matplotlib.colors` for full list of formats that are accepted. box_facecolors (str): colors of the box wrapped around the text. Refer to `matplotlib.colors` for full list of formats that are accepted. alpha (float): transparency level of the box. """ if not isinstance(box_facecolors, list): box_facecolors = [box_facecolors] * len(text_ls) assert len(box_facecolors) == len( text_ls ), "Number of colors provided is not equal to the number of text labels." assert y_corner in [1, 3], "Y_corner must be either 1 or 3" if not font_size: font_size = self._default_font_size x, horizontal_alignment = self._align_x_coordinate(box_coordinate) y = box_coordinate[y_corner].item() for i, text in enumerate(text_ls): self.draw_text( text, (x, y), font_size=font_size, color=color, horizontal_alignment=horizontal_alignment, vertical_alignment="top", box_facecolor=box_facecolors[i], alpha=alpha, ) y += font_size + font_size // 2 def _align_x_coordinate(self, box_coordinate: torch.Tensor) -> Tuple[float, str]: """ Choose an x-coordinate from the box to make sure the text label does not go out of frames. By default, the left x-coordinate is chosen and text is aligned left. If the box is too close to the right side of the image, then the right x-coordinate is chosen instead and the text is aligned right. Args: box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom) coordinates of the box. Returns: x_coordinate (float): the chosen x-coordinate. alignment (str): whether to align left or right. """ # If the x-coordinate is greater than 5/6 of the image width, # then we align test to the right of the box. This is # chosen by heuristics. if box_coordinate[0] > (self.output.width * 5) // 6: return box_coordinate[2], "right" return box_coordinate[0], "left" def _align_y_top( self, box_coordinate: torch.Tensor, num_text: int, textbox_width: float ) -> int: """ Calculate the number of text labels to plot on top of the box without going out of frames. Args: box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom) coordinates of the box. num_text (int): the number of text labels to plot. textbox_width (float): the width of the box wrapped around text label. """ dist_to_top = box_coordinate[1] num_text_top = dist_to_top // textbox_width if isinstance(num_text_top, torch.Tensor): num_text_top = int(num_text_top.item()) return min(num_text, num_text_top) def _align_y_bottom( self, box_coordinate: torch.Tensor, num_text: int, textbox_width: float ) -> int: """ Calculate the number of text labels to plot at the bottom of the box without going out of frames. Args: box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom) coordinates of the box. num_text (int): the number of text labels to plot. textbox_width (float): the width of the box wrapped around text label. """ dist_to_bottom = self.output.height - box_coordinate[3] num_text_bottom = dist_to_bottom // textbox_width if isinstance(num_text_bottom, torch.Tensor): num_text_bottom = int(num_text_bottom.item()) return min(num_text, num_text_bottom) class VideoVisualizer: def __init__( self, num_classes: int, class_names: Dict, top_k: int = 1, colormap: str = "rainbow", thres: float = 0.7, lower_thres: float = 0.3, common_class_names: Optional[List[str]] = None, mode: str = "top-k", ) -> None: """ Args: num_classes (int): total number of classes. class_names (dict): Dict mapping classID to name. top_k (int): number of top predicted classes to plot. colormap (str): the colormap to choose color for class labels from. See https://matplotlib.org/tutorials/colors/colormaps.html thres (float): threshold for picking predicted classes to visualize. lower_thres (Optional[float]): If `common_class_names` if given, this `lower_thres` will be applied to uncommon classes and `thres` will be applied to classes in `common_class_names`. common_class_names (Optional[list of str]): list of common class names to apply `thres`. Class names not included in `common_class_names` will have `lower_thres` as a threshold. If None, all classes will have `thres` as a threshold. This is helpful for model trained on highly imbalanced dataset. mode (str): Supported modes are {"top-k", "thres"}. This is used for choosing predictions for visualization. """ assert mode in ["top-k", "thres"], "Mode {} is not supported.".format(mode) self.mode = mode self.num_classes = num_classes self.class_names = class_names self.top_k = top_k self.thres = thres self.lower_thres = lower_thres if mode == "thres": self._get_thres_array(common_class_names=common_class_names) self.color_map = plt.get_cmap(colormap) def _get_color(self, class_id: int) -> List[float]: """ Get color for a class id. Args: class_id (int): class id. """ return self.color_map(class_id / self.num_classes)[:3] def draw_one_frame( self, frame: Union[torch.Tensor, np.ndarray], preds: Union[torch.Tensor, List[float]], bboxes: Optional[torch.Tensor] = None, alpha: float = 0.5, text_alpha: float = 0.7, ground_truth: bool = False, ) -> np.ndarray: """ Draw labels and bouding boxes for one image. By default, predicted labels are drawn in the top left corner of the image or corresponding bounding boxes. For ground truth labels (setting True for ground_truth flag), labels will be drawn in the bottom left corner. Args: frame (array-like): a tensor or numpy array of shape (H, W, C), where H and W correspond to the height and width of the image respectively. C is the number of color channels. The image is required to be in RGB format since that is a requirement of the Matplotlib library. The image is also expected to be in the range [0, 255]. preds (tensor or list): If ground_truth is False, provide a float tensor of shape (num_boxes, num_classes) that contains all of the confidence scores of the model. For recognition task, input shape can be (num_classes,). To plot true label (ground_truth is True), preds is a list contains int32 of the shape (num_boxes, true_class_ids) or (true_class_ids,). bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates of the bounding boxes. alpha (Optional[float]): transparency level of the bounding boxes. text_alpha (Optional[float]): transparency level of the box wrapped around text labels. ground_truth (bool): whether the prodived bounding boxes are ground-truth. Returns: An image with bounding box annotations and corresponding bbox labels plotted on it. """ if isinstance(preds, torch.Tensor): if preds.ndim == 1: preds = preds.unsqueeze(0) n_instances = preds.shape[0] elif isinstance(preds, list): n_instances = len(preds) else: # logger.error("Unsupported type of prediction input.") return if ground_truth: top_scores, top_classes = [None] * n_instances, preds elif self.mode == "top-k": top_scores, top_classes = torch.topk(preds, k=self.top_k) top_scores, top_classes = top_scores.tolist(), top_classes.tolist() elif self.mode == "thres": top_scores, top_classes = [], [] for pred in preds: mask = pred >= self.thres top_scores.append(pred[mask].tolist()) top_class = torch.squeeze(torch.nonzero(mask), dim=-1).tolist() top_classes.append(top_class) # Create labels top k predicted classes with their scores. text_labels = [] for i in range(n_instances): text_labels.append( _create_text_labels( top_classes[i], top_scores[i], self.class_names, ground_truth=ground_truth, ) ) frame_visualizer = ImgVisualizer(frame, meta=None) font_size = min(max(np.sqrt(frame.shape[0] * frame.shape[1]) // 25, 5), 9) top_corner = not ground_truth if bboxes is not None: assert len(preds) == len( bboxes ), "Encounter {} predictions and {} bounding boxes".format( len(preds), len(bboxes) ) for i, box in enumerate(bboxes): text = text_labels[i] pred_class = top_classes[i] colors = [self._get_color(pred) for pred in pred_class] box_color = "r" if ground_truth else "g" line_style = "--" if ground_truth else "-." frame_visualizer.draw_box( box, alpha=alpha, edge_color=box_color, line_style=line_style, ) frame_visualizer.draw_multiple_text( text, box, top_corner=top_corner, font_size=font_size, box_facecolors=colors, alpha=text_alpha, ) else: text = text_labels[0] pred_class = top_classes[0] colors = [self._get_color(pred) for pred in pred_class] frame_visualizer.draw_multiple_text( text, torch.Tensor([0, 5, frame.shape[1], frame.shape[0] - 5]), top_corner=top_corner, font_size=font_size, box_facecolors=colors, alpha=text_alpha, ) return frame_visualizer.output.get_image() def draw_clip_range( self, frames: Union[torch.Tensor, np.ndarray], preds: Union[torch.Tensor, List[float]], bboxes: Optional[torch.Tensor] = None, text_alpha: float = 0.5, ground_truth: bool = False, keyframe_idx: Optional[int] = None, draw_range: Optional[List[int]] = None, repeat_frame: int = 1, ) -> List[np.ndarray]: """ Draw predicted labels or ground truth classes to clip. Draw bouding boxes to clip if bboxes is provided. Boxes will gradually fade in and out the clip, centered around the clip's central frame, within the provided `draw_range`. Args: frames (array-like): video data in the shape (T, H, W, C). preds (tensor): a tensor of shape (num_boxes, num_classes) that contains all of the confidence scores of the model. For recognition task or for ground_truth labels, input shape can be (num_classes,). bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates of the bounding boxes. text_alpha (float): transparency label of the box wrapped around text labels. ground_truth (bool): whether the prodived bounding boxes are ground-truth. keyframe_idx (int): the index of keyframe in the clip. draw_range (Optional[list[ints]): only draw frames in range [start_idx, end_idx] inclusively in the clip. If None, draw on the entire clip. repeat_frame (int): repeat each frame in draw_range for `repeat_frame` time for slow-motion effect. Returns: A list of frames with bounding box annotations and corresponding bbox labels ploted on them. """ if draw_range is None: draw_range = [0, len(frames) - 1] if draw_range is not None: draw_range[0] = max(0, draw_range[0]) left_frames = frames[: draw_range[0]] right_frames = frames[draw_range[1] + 1 :] draw_frames = frames[draw_range[0] : draw_range[1] + 1] if keyframe_idx is None: keyframe_idx = len(frames) // 2 img_ls = ( list(left_frames) + self.draw_clip( draw_frames, preds, bboxes=bboxes, text_alpha=text_alpha, ground_truth=ground_truth, keyframe_idx=keyframe_idx - draw_range[0], repeat_frame=repeat_frame, ) + list(right_frames) ) return img_ls def draw_clip( self, frames: Union[torch.Tensor, np.ndarray], preds: Union[torch.Tensor, List[float]], bboxes: Optional[torch.Tensor] = None, text_alpha: float = 0.5, ground_truth: bool = False, keyframe_idx: Optional[int] = None, repeat_frame: int = 1, ) -> List[np.ndarray]: """ Draw predicted labels or ground truth classes to clip. Draw bouding boxes to clip if bboxes is provided. Boxes will gradually fade in and out the clip, centered around the clip's central frame. Args: frames (array-like): video data in the shape (T, H, W, C). preds (tensor): a tensor of shape (num_boxes, num_classes) that contains all of the confidence scores of the model. For recognition task or for ground_truth labels, input shape can be (num_classes,). bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates of the bounding boxes. text_alpha (float): transparency label of the box wrapped around text labels. ground_truth (bool): whether the prodived bounding boxes are ground-truth. keyframe_idx (int): the index of keyframe in the clip. repeat_frame (int): repeat each frame in draw_range for `repeat_frame` time for slow-motion effect. Returns: A list of frames with bounding box annotations and corresponding bbox labels plotted on them. """ assert repeat_frame >= 1, "`repeat_frame` must be a positive integer." repeated_seq = range(0, len(frames)) repeated_seq = list( itertools.chain.from_iterable( itertools.repeat(x, repeat_frame) for x in repeated_seq ) ) frames, adjusted = self._adjust_frames_type(frames) if keyframe_idx is None: half_left = len(repeated_seq) // 2 half_right = (len(repeated_seq) + 1) // 2 else: mid = int((keyframe_idx / len(frames)) * len(repeated_seq)) half_left = mid half_right = len(repeated_seq) - mid alpha_ls = np.concatenate( [ np.linspace(0, 1, num=half_left), np.linspace(1, 0, num=half_right), ] ) text_alpha = text_alpha frames = frames[repeated_seq] img_ls = [] for alpha, frame in zip(alpha_ls, frames): draw_img = self.draw_one_frame( frame, preds, bboxes, alpha=alpha, text_alpha=text_alpha, ground_truth=ground_truth, ) if adjusted: draw_img = draw_img.astype("float32") / 255 img_ls.append(draw_img) return img_ls def _adjust_frames_type( self, frames: torch.Tensor ) -> Tuple[List[np.ndarray], bool]: """ Modify video data to have dtype of uint8 and values range in [0, 255]. Args: frames (array-like): 4D array of shape (T, H, W, C). Returns: frames (list of frames): list of frames in range [0, 1]. adjusted (bool): whether the original frames need adjusted. """ assert ( frames is not None and len(frames) != 0 ), "Frames does not contain any values" frames = np.array(frames) assert np.array(frames).ndim == 4, "Frames must have 4 dimensions" adjusted = False if frames.dtype in [np.float32, np.float64]: frames *= 255 frames = frames.astype(np.uint8) adjusted = True return frames, adjusted def _get_thres_array(self, common_class_names: Optional[List[str]] = None) -> None: """ Compute a thresholds array for all classes based on `self.thes` and `self.lower_thres`. Args: common_class_names (Optional[list of str]): a list of common class names. """ common_class_ids = [] if common_class_names is not None: common_classes = set(common_class_names) for key, name in self.class_names.items(): if name in common_classes: common_class_ids.append(key) else: common_class_ids = list(range(self.num_classes)) thres_array = np.full(shape=(self.num_classes,), fill_value=self.lower_thres) thres_array[common_class_ids] = self.thres self.thres = torch.from_numpy(thres_array)