|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
import itertools |
|
import logging |
|
from types import SimpleNamespace |
|
from typing import Dict, List, Optional, Tuple, Union |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import torch |
|
from detectron2.utils.visualizer import Visualizer |
|
|
|
|
|
# Module-level logger; the functions below use it to report invalid inputs.
logger = logging.getLogger(__name__)
|
|
|
|
|
def _create_text_labels( |
|
classes: List[int], |
|
scores: List[float], |
|
class_names: List[str], |
|
ground_truth: bool = False, |
|
) -> List[str]: |
|
""" |
|
Create text labels. |
|
Args: |
|
classes (list[int]): a list of class ids for each example. |
|
scores (list[float] or None): list of scores for each example. |
|
class_names (list[str]): a list of class names, ordered by their ids. |
|
ground_truth (bool): whether the labels are ground truth. |
|
Returns: |
|
labels (list[str]): formatted text labels. |
|
""" |
|
try: |
|
labels = [class_names.get(c, "n/a") for c in classes] |
|
except IndexError: |
|
logger.error("Class indices get out of range: {}".format(classes)) |
|
return None |
|
|
|
if ground_truth: |
|
labels = ["[{}] {}".format("GT", label) for label in labels] |
|
elif scores is not None: |
|
assert len(classes) == len(scores) |
|
labels = ["[{:.2f}] {}".format(s, label) for s, label in zip(scores, labels)] |
|
return labels |
|
|
|
|
|
class ImgVisualizer(Visualizer):
    """
    Extension of detectron2's `Visualizer` that can stack several text labels
    above and/or below a reference y-coordinate of a bounding box.
    """

    def __init__(
        self, img_rgb: torch.Tensor, meta: Optional[SimpleNamespace] = None, **kwargs
    ) -> None:
        """
        See https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/visualizer.py
        for more details.
        Args:
            img_rgb: a tensor or numpy array of shape (H, W, C), where H and W correspond to
                the height and width of the image respectively. C is the number of
                color channels. The image is required to be in RGB format since that
                is a requirement of the Matplotlib library. The image is also expected
                to be in the range [0, 255].
            meta (MetadataCatalog): image metadata.
                See https://github.com/facebookresearch/detectron2/blob/81d5a87763bfc71a492b5be89b74179bd7492f6b/detectron2/data/catalog.py#L90
        """
        super().__init__(img_rgb, meta, **kwargs)

    @staticmethod
    def _expand_facecolors(box_facecolors, num_text: int) -> list:
        """
        Return one face color per text label.
        Args:
            box_facecolors: a single color, or a list with one color per label.
            num_text (int): the number of text labels.
        Returns:
            A list of `num_text` colors. Anything that is not a list is treated
            as a single color and repeated.
        """
        if not isinstance(box_facecolors, list):
            box_facecolors = [box_facecolors] * num_text
        assert (
            len(box_facecolors) == num_text
        ), "Number of colors provided is not equal to the number of text labels."
        return box_facecolors

    def draw_text(
        self,
        text: str,
        position: List[int],
        *,
        font_size: Optional[int] = None,
        color: str = "w",
        horizontal_alignment: str = "center",
        vertical_alignment: str = "bottom",
        box_facecolor: str = "black",
        alpha: float = 0.5,
    ) -> None:
        """
        Draw text at the specified position.
        Args:
            text (str): the text to draw on image.
            position (list of 2 ints): the x,y coordinate to place the text.
            font_size (Optional[int]): font of the text. If not provided, the
                default font size of the visualizer is used.
            color (str): color of the text. Refer to `matplotlib.colors` for full list
                of formats that are accepted.
            horizontal_alignment (str): see `matplotlib.text.Text`.
            vertical_alignment (str): see `matplotlib.text.Text`.
            box_facecolor (str): color of the box wrapped around the text. Refer to
                `matplotlib.colors` for full list of formats that are accepted.
            alpha (float): transparency level of the box.
        """
        if not font_size:
            font_size = self._default_font_size
        x, y = position
        self.output.ax.text(
            x,
            y,
            text,
            size=font_size * self.output.scale,
            family="monospace",
            bbox={
                "facecolor": box_facecolor,
                "alpha": alpha,
                "pad": 0.7,
                "edgecolor": "none",
            },
            verticalalignment=vertical_alignment,
            horizontalalignment=horizontal_alignment,
            color=color,
            zorder=10,
        )

    def draw_multiple_text(
        self,
        text_ls: List[str],
        box_coordinate: torch.Tensor,
        *,
        top_corner: bool = True,
        font_size: Optional[int] = None,
        color: str = "w",
        box_facecolors: Union[str, List[str]] = "black",
        alpha: float = 0.5,
    ) -> None:
        """
        Draw a list of text labels for some bounding box on the image.
        Args:
            text_ls (list of strings): a list of text labels.
            box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            top_corner (bool): If True, draw the text labels at (x_left, y_top) of the box.
                Else, draw labels at (x_left, y_bottom).
            font_size (Optional[int]): font of the text. If not provided, the
                default font size of the visualizer is used.
            color (str): color of the text. Refer to `matplotlib.colors` for full list
                of formats that are accepted.
            box_facecolors (str or list of strs): colors of the box wrapped around the
                text. Refer to `matplotlib.colors` for full list of formats that
                are accepted.
            alpha (float): transparency level of the box.
        """
        box_facecolors = self._expand_facecolors(box_facecolors, len(text_ls))
        if not font_size:
            font_size = self._default_font_size
        text_box_width = font_size + font_size // 2

        # Decide how many labels are stacked upward from the anchor; the rest
        # are stacked downward so that they stay inside the frame.
        if top_corner:
            num_text_split = self._align_y_top(
                box_coordinate, len(text_ls), text_box_width
            )
            y_corner = 1
        else:
            num_text_split = len(text_ls) - self._align_y_bottom(
                box_coordinate, len(text_ls), text_box_width
            )
            y_corner = 3

        # Sort labels (together with their colors) in descending text order.
        pairs = sorted(
            zip(text_ls, box_facecolors), key=lambda pair: pair[0], reverse=True
        )
        text_ls = [text for text, _ in pairs]
        box_facecolors = [facecolor for _, facecolor in pairs]

        self.draw_multiple_text_upward(
            text_ls[:num_text_split][::-1],
            box_coordinate,
            y_corner=y_corner,
            font_size=font_size,
            color=color,
            box_facecolors=box_facecolors[:num_text_split][::-1],
            alpha=alpha,
        )
        self.draw_multiple_text_downward(
            text_ls[num_text_split:],
            box_coordinate,
            y_corner=y_corner,
            font_size=font_size,
            color=color,
            box_facecolors=box_facecolors[num_text_split:],
            alpha=alpha,
        )

    def draw_multiple_text_upward(
        self,
        text_ls: List[str],
        box_coordinate: torch.Tensor,
        *,
        y_corner: int = 1,
        font_size: Optional[int] = None,
        color: str = "w",
        box_facecolors: Union[str, List[str]] = "black",
        alpha: float = 0.5,
    ) -> None:
        """
        Draw a list of text labels for some bounding box on the image in upward direction.
        The next text label will be on top of the previous one.
        Args:
            text_ls (list of strings): a list of text labels.
            box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            y_corner (int): Value of either 1 or 3. Indicate the index of the y-coordinate of
                the box to draw labels around.
            font_size (Optional[int]): font of the text. If not provided, the
                default font size of the visualizer is used.
            color (str): color of the text. Refer to `matplotlib.colors` for full list
                of formats that are accepted.
            box_facecolors (str or list of strs): colors of the box wrapped around the
                text. Refer to `matplotlib.colors` for full list of formats that
                are accepted.
            alpha (float): transparency level of the box.
        """
        box_facecolors = self._expand_facecolors(box_facecolors, len(text_ls))

        assert y_corner in [1, 3], "Y_corner must be either 1 or 3"
        if not font_size:
            font_size = self._default_font_size

        x, horizontal_alignment = self._align_x_coordinate(box_coordinate)
        y = box_coordinate[y_corner].item()
        line_height = font_size + font_size // 2
        for text, facecolor in zip(text_ls, box_facecolors):
            self.draw_text(
                text,
                (x, y),
                font_size=font_size,
                color=color,
                horizontal_alignment=horizontal_alignment,
                vertical_alignment="bottom",
                box_facecolor=facecolor,
                alpha=alpha,
            )
            # Move up by one text-box height for the next label.
            y -= line_height

    def draw_multiple_text_downward(
        self,
        text_ls: List[str],
        box_coordinate: torch.Tensor,
        *,
        y_corner: int = 1,
        font_size: Optional[int] = None,
        color: str = "w",
        box_facecolors: Union[str, List[str]] = "black",
        alpha: float = 0.5,
    ) -> None:
        """
        Draw a list of text labels for some bounding box on the image in downward direction.
        The next text label will be below the previous one.
        Args:
            text_ls (list of strings): a list of text labels.
            box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            y_corner (int): Value of either 1 or 3. Indicate the index of the y-coordinate of
                the box to draw labels around.
            font_size (Optional[int]): font of the text. If not provided, the
                default font size of the visualizer is used.
            color (str): color of the text. Refer to `matplotlib.colors` for full list
                of formats that are accepted.
            box_facecolors (str or list of strs): colors of the box wrapped around the
                text. Refer to `matplotlib.colors` for full list of formats that
                are accepted.
            alpha (float): transparency level of the box.
        """
        box_facecolors = self._expand_facecolors(box_facecolors, len(text_ls))

        assert y_corner in [1, 3], "Y_corner must be either 1 or 3"
        if not font_size:
            font_size = self._default_font_size

        x, horizontal_alignment = self._align_x_coordinate(box_coordinate)
        y = box_coordinate[y_corner].item()
        line_height = font_size + font_size // 2
        for text, facecolor in zip(text_ls, box_facecolors):
            self.draw_text(
                text,
                (x, y),
                font_size=font_size,
                color=color,
                horizontal_alignment=horizontal_alignment,
                vertical_alignment="top",
                box_facecolor=facecolor,
                alpha=alpha,
            )
            # Move down by one text-box height for the next label.
            y += line_height

    def _align_x_coordinate(self, box_coordinate: torch.Tensor) -> Tuple[float, str]:
        """
        Choose an x-coordinate from the box to make sure the text label
        does not go out of frames. By default, the left x-coordinate is
        chosen and text is aligned left. If the box is too close to the
        right side of the image, then the right x-coordinate is chosen
        instead and the text is aligned right.
        Args:
            box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
        Returns:
            x_coordinate (float): the chosen x-coordinate.
            alignment (str): whether to align left or right.
        """
        # "Too close to the right side" means the left edge of the box lies in
        # the right-most sixth of the image. The coordinate is converted to a
        # plain float so that the declared return type is honored.
        if box_coordinate[0] > (self.output.width * 5) // 6:
            return float(box_coordinate[2]), "right"

        return float(box_coordinate[0]), "left"

    def _align_y_top(
        self, box_coordinate: torch.Tensor, num_text: int, textbox_width: float
    ) -> int:
        """
        Calculate the number of text labels to plot on top of the box
        without going out of frames.
        Args:
            box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            num_text (int): the number of text labels to plot.
            textbox_width (float): the width of the box wrapped around text label.
        Returns:
            An int in the range [0, num_text].
        """
        dist_to_top = box_coordinate[1]
        num_text_top = dist_to_top // textbox_width

        if isinstance(num_text_top, torch.Tensor):
            num_text_top = num_text_top.item()

        # The result is used as a slice index by the caller, so it must be a
        # non-negative int: a box whose top edge lies above the frame would
        # otherwise produce a negative count and split the labels incorrectly.
        return max(0, min(num_text, int(num_text_top)))

    def _align_y_bottom(
        self, box_coordinate: torch.Tensor, num_text: int, textbox_width: float
    ) -> int:
        """
        Calculate the number of text labels to plot at the bottom of the box
        without going out of frames.
        Args:
            box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            num_text (int): the number of text labels to plot.
            textbox_width (float): the width of the box wrapped around text label.
        Returns:
            An int in the range [0, num_text].
        """
        dist_to_bottom = self.output.height - box_coordinate[3]
        num_text_bottom = dist_to_bottom // textbox_width

        if isinstance(num_text_bottom, torch.Tensor):
            num_text_bottom = num_text_bottom.item()

        # Same clamping as `_align_y_top`: a box whose bottom edge lies below
        # the frame must yield 0 rather than a negative count.
        return max(0, min(num_text, int(num_text_bottom)))
|
|
|
|
|
class VideoVisualizer:
    """
    Draw predicted or ground-truth class labels (and optional bounding boxes)
    on single frames or on whole clips.
    """

    def __init__(
        self,
        num_classes: int,
        class_names: Dict,
        top_k: int = 1,
        colormap: str = "rainbow",
        thres: float = 0.7,
        lower_thres: float = 0.3,
        common_class_names: Optional[List[str]] = None,
        mode: str = "top-k",
    ) -> None:
        """
        Args:
            num_classes (int): total number of classes.
            class_names (dict): Dict mapping classID to name.
            top_k (int): number of top predicted classes to plot.
            colormap (str): the colormap to choose color for class labels from.
                See https://matplotlib.org/tutorials/colors/colormaps.html
            thres (float): threshold for picking predicted classes to visualize.
            lower_thres (Optional[float]): If `common_class_names` is given,
                this `lower_thres` will be applied to uncommon classes and
                `thres` will be applied to classes in `common_class_names`.
            common_class_names (Optional[list of str]): list of common class names
                to apply `thres`. Class names not included in `common_class_names` will
                have `lower_thres` as a threshold. If None, all classes will have
                `thres` as a threshold. This is helpful for model trained on
                highly imbalanced dataset.
            mode (str): Supported modes are {"top-k", "thres"}.
                This is used for choosing predictions for visualization.
        """
        assert mode in ["top-k", "thres"], "Mode {} is not supported.".format(mode)
        self.mode = mode
        self.num_classes = num_classes
        self.class_names = class_names
        self.top_k = top_k
        self.thres = thres
        self.lower_thres = lower_thres

        if mode == "thres":
            # Replaces the scalar `self.thres` with a per-class threshold tensor.
            self._get_thres_array(common_class_names=common_class_names)

        self.color_map = plt.get_cmap(colormap)

    def _get_color(self, class_id: int) -> List[float]:
        """
        Get color for a class id.
        Args:
            class_id (int): class id.
        Returns:
            The first three components of the colormap entry sampled at
            `class_id / num_classes`.
        """
        return self.color_map(class_id / self.num_classes)[:3]

    def draw_one_frame(
        self,
        frame: Union[torch.Tensor, np.ndarray],
        preds: Union[torch.Tensor, List[float]],
        bboxes: Optional[torch.Tensor] = None,
        alpha: float = 0.5,
        text_alpha: float = 0.7,
        ground_truth: bool = False,
    ) -> np.ndarray:
        """
        Draw labels and bounding boxes for one image. By default, predicted
        labels are drawn in the top left corner of the image or corresponding
        bounding boxes. For ground truth labels (setting True for ground_truth flag),
        labels will be drawn in the bottom left corner.
        Args:
            frame (array-like): a tensor or numpy array of shape (H, W, C),
                where H and W correspond to
                the height and width of the image respectively. C is the number of
                color channels. The image is required to be in RGB format since that
                is a requirement of the Matplotlib library. The image is also expected
                to be in the range [0, 255].
            preds (tensor or list): If ground_truth is False, provide a float tensor of
                shape (num_boxes, num_classes) that contains all of the confidence
                scores of the model. For recognition task, input shape can be (num_classes,).
                To plot true label (ground_truth is True), preds is a list contains int32
                of the shape (num_boxes, true_class_ids) or (true_class_ids,).
            bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates
                of the bounding boxes.
            alpha (Optional[float]): transparency level of the bounding boxes.
            text_alpha (Optional[float]): transparency level of the box wrapped around
                text labels.
            ground_truth (bool): whether the provided bounding boxes are ground-truth.
        Returns:
            An image with bounding box annotations and corresponding bbox
            labels plotted on it. None is returned (after logging an error) when
            `preds` is neither a tensor nor a list.
        """
        if isinstance(preds, torch.Tensor):
            if preds.ndim == 1:
                preds = preds.unsqueeze(0)
            n_instances = preds.shape[0]
        elif isinstance(preds, list):
            # A flat list of ground-truth class ids describes a single
            # instance; wrap it so it is handled like a 1-D tensor.
            if ground_truth and preds and isinstance(preds[0], (int, np.integer)):
                preds = [preds]
            n_instances = len(preds)
        else:
            logger.error("Unsupported type of prediction input.")
            return

        if ground_truth:
            # Convert tensors to plain Python ids so that they can be used as
            # keys of `self.class_names`.
            top_classes = preds.tolist() if isinstance(preds, torch.Tensor) else preds
            top_scores = [None] * n_instances
        elif self.mode == "top-k":
            # Never ask for more entries than there are classes.
            k = min(self.top_k, preds.shape[-1])
            top_scores, top_classes = torch.topk(preds, k=k)
            top_scores, top_classes = top_scores.tolist(), top_classes.tolist()
        elif self.mode == "thres":
            top_scores, top_classes = [], []
            for pred in preds:
                mask = pred >= self.thres
                top_scores.append(pred[mask].tolist())
                top_class = torch.squeeze(torch.nonzero(mask), dim=-1).tolist()
                top_classes.append(top_class)

        text_labels = [
            _create_text_labels(
                instance_classes,
                instance_scores,
                self.class_names,
                ground_truth=ground_truth,
            )
            for instance_classes, instance_scores in zip(top_classes, top_scores)
        ]
        frame_visualizer = ImgVisualizer(frame, meta=None)
        # Font size grows with the image size but is kept within [5, 9].
        font_size = min(max(np.sqrt(frame.shape[0] * frame.shape[1]) // 25, 5), 9)
        top_corner = not ground_truth
        if bboxes is not None:
            assert len(preds) == len(
                bboxes
            ), "Encounter {} predictions and {} bounding boxes".format(
                len(preds), len(bboxes)
            )
            box_color = "r" if ground_truth else "g"
            line_style = "--" if ground_truth else "-."
            for i, box in enumerate(bboxes):
                colors = [self._get_color(pred) for pred in top_classes[i]]
                frame_visualizer.draw_box(
                    box,
                    alpha=alpha,
                    edge_color=box_color,
                    line_style=line_style,
                )
                frame_visualizer.draw_multiple_text(
                    text_labels[i],
                    box,
                    top_corner=top_corner,
                    font_size=font_size,
                    box_facecolors=colors,
                    alpha=text_alpha,
                )
        else:
            # Without boxes, only the first instance is drawn, anchored to a
            # box covering the frame with a 5-pixel vertical margin.
            colors = [self._get_color(pred) for pred in top_classes[0]]
            frame_visualizer.draw_multiple_text(
                text_labels[0],
                torch.Tensor([0, 5, frame.shape[1], frame.shape[0] - 5]),
                top_corner=top_corner,
                font_size=font_size,
                box_facecolors=colors,
                alpha=text_alpha,
            )

        return frame_visualizer.output.get_image()

    def draw_clip_range(
        self,
        frames: Union[torch.Tensor, np.ndarray],
        preds: Union[torch.Tensor, List[float]],
        bboxes: Optional[torch.Tensor] = None,
        text_alpha: float = 0.5,
        ground_truth: bool = False,
        keyframe_idx: Optional[int] = None,
        draw_range: Optional[List[int]] = None,
        repeat_frame: int = 1,
    ) -> List[np.ndarray]:
        """
        Draw predicted labels or ground truth classes to clip.
        Draw bounding boxes to clip if bboxes is provided. Boxes will gradually
        fade in and out the clip, centered around the clip's central frame,
        within the provided `draw_range`.
        Args:
            frames (array-like): video data in the shape (T, H, W, C).
            preds (tensor): a tensor of shape (num_boxes, num_classes) that
                contains all of the confidence scores of the model. For recognition
                task or for ground_truth labels, input shape can be (num_classes,).
            bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates
                of the bounding boxes.
            text_alpha (float): transparency label of the box wrapped around text labels.
            ground_truth (bool): whether the provided bounding boxes are ground-truth.
            keyframe_idx (int): the index of keyframe in the clip. Defaults to the
                middle frame of `frames`.
            draw_range (Optional[list[ints]): only draw frames in range
                [start_idx, end_idx] inclusively in the clip. If None, draw on
                the entire clip. The list is not modified.
            repeat_frame (int): repeat each frame in draw_range for `repeat_frame`
                time for slow-motion effect.
        Returns:
            A list of frames with bounding box annotations and corresponding
            bbox labels plotted on them. Frames outside `draw_range` are
            returned untouched.
        """
        # Work on local indices so the caller's `draw_range` is left intact.
        if draw_range is None:
            start_idx, end_idx = 0, len(frames) - 1
        else:
            start_idx, end_idx = draw_range[0], draw_range[1]
        start_idx = max(0, start_idx)

        left_frames = frames[:start_idx]
        draw_frames = frames[start_idx : end_idx + 1]
        right_frames = frames[end_idx + 1 :]

        if keyframe_idx is None:
            keyframe_idx = len(frames) // 2

        drawn_frames = self.draw_clip(
            draw_frames,
            preds,
            bboxes=bboxes,
            text_alpha=text_alpha,
            ground_truth=ground_truth,
            # The keyframe index is relative to the drawn sub-clip.
            keyframe_idx=keyframe_idx - start_idx,
            repeat_frame=repeat_frame,
        )

        return list(left_frames) + drawn_frames + list(right_frames)

    def draw_clip(
        self,
        frames: Union[torch.Tensor, np.ndarray],
        preds: Union[torch.Tensor, List[float]],
        bboxes: Optional[torch.Tensor] = None,
        text_alpha: float = 0.5,
        ground_truth: bool = False,
        keyframe_idx: Optional[int] = None,
        repeat_frame: int = 1,
    ) -> List[np.ndarray]:
        """
        Draw predicted labels or ground truth classes to clip. Draw bounding boxes to clip
        if bboxes is provided. Boxes will gradually fade in and out the clip, centered
        around the clip's central frame.
        Args:
            frames (array-like): video data in the shape (T, H, W, C).
            preds (tensor): a tensor of shape (num_boxes, num_classes) that contains
                all of the confidence scores of the model. For recognition task or for
                ground_truth labels, input shape can be (num_classes,).
            bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates
                of the bounding boxes.
            text_alpha (float): transparency label of the box wrapped around text labels.
            ground_truth (bool): whether the provided bounding boxes are ground-truth.
            keyframe_idx (int): the index of keyframe in the clip. A keyframe that
                falls outside the clip is clamped to the nearest end.
            repeat_frame (int): repeat each frame in draw_range for `repeat_frame`
                time for slow-motion effect.
        Returns:
            A list of frames with bounding box annotations and corresponding
            bbox labels plotted on them.
        """
        assert repeat_frame >= 1, "`repeat_frame` must be a positive integer."

        # Frame indices with every index repeated `repeat_frame` times.
        repeated_seq = list(
            itertools.chain.from_iterable(
                itertools.repeat(x, repeat_frame) for x in range(len(frames))
            )
        )
        num_output = len(repeated_seq)

        frames, adjusted = self._adjust_frames_type(frames)
        if keyframe_idx is None:
            half_left = num_output // 2
            half_right = (num_output + 1) // 2
        else:
            mid = int((keyframe_idx / len(frames)) * num_output)
            # Keep the split inside the clip: `np.linspace` rejects a negative
            # number of samples, which happens when the keyframe lies outside
            # the frames being drawn.
            mid = min(max(mid, 0), num_output)
            half_left = mid
            half_right = num_output - mid

        # Box opacity ramps up to the keyframe and back down afterwards.
        alpha_ls = np.concatenate(
            [
                np.linspace(0, 1, num=half_left),
                np.linspace(1, 0, num=half_right),
            ]
        )
        frames = frames[repeated_seq]
        img_ls = []
        for alpha, frame in zip(alpha_ls, frames):
            draw_img = self.draw_one_frame(
                frame,
                preds,
                bboxes,
                alpha=alpha,
                text_alpha=text_alpha,
                ground_truth=ground_truth,
            )
            if adjusted:
                # Convert back to the float range the input frames used.
                draw_img = draw_img.astype("float32") / 255

            img_ls.append(draw_img)

        return img_ls

    def _adjust_frames_type(self, frames: torch.Tensor) -> Tuple[np.ndarray, bool]:
        """
        Modify video data to have dtype of uint8 and values range in [0, 255].
        Args:
            frames (array-like): 4D array of shape (T, H, W, C).
        Returns:
            frames (np.ndarray): a copy of the frames. Floating point input is
                multiplied by 255 and cast to uint8; other dtypes are returned
                as they are.
            adjusted (bool): whether the original frames needed to be adjusted.
        """
        assert (
            frames is not None and len(frames) != 0
        ), "Frames does not contain any values"
        # `np.array` copies, so the caller's data is never modified in place.
        frames = np.array(frames)
        assert frames.ndim == 4, "Frames must have 4 dimensions"
        adjusted = False
        if np.issubdtype(frames.dtype, np.floating):
            frames *= 255
            frames = frames.astype(np.uint8)
            adjusted = True

        return frames, adjusted

    def _get_thres_array(self, common_class_names: Optional[List[str]] = None) -> None:
        """
        Compute a thresholds array for all classes based on `self.thres` and
        `self.lower_thres`, and store it in `self.thres` as a tensor.
        Args:
            common_class_names (Optional[list of str]): a list of common class names.
                If None, every class is treated as common.
        """
        if common_class_names is not None:
            common_classes = set(common_class_names)
            common_class_ids = [
                key for key, name in self.class_names.items() if name in common_classes
            ]
        else:
            common_class_ids = list(range(self.num_classes))

        # Start from the lower threshold everywhere, then raise it for the
        # common classes.
        thres_array = np.full(shape=(self.num_classes,), fill_value=self.lower_thres)
        thres_array[common_class_ids] = self.thres
        self.thres = torch.from_numpy(thres_array)
|
|