# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Note: This file has been borrowed from the facebookresearch/slowfast repo. It is used to add bounding boxes and predictions to the frame.
# TODO: Migrate this into the core PyTorchVideo library.
from __future__ import annotations

import itertools
# import logging
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
from detectron2.utils.visualizer import Visualizer


# logger = logging.getLogger(__name__)


def _create_text_labels(
    classes: List[int],
    scores: Optional[List[float]],
    class_names: Dict[int, str],
    ground_truth: bool = False,
) -> List[str]:
    """
    Create text labels.
    Args:
        classes (list[int]): a list of class ids for each example.
        scores (list[float] or None): list of scores for each example.
        class_names (dict[int, str]): a dict mapping class ids to class names.
        ground_truth (bool): whether the labels are ground truth.
    Returns:
        labels (list[str]): formatted text labels.
    """
    labels = [class_names.get(c, "n/a") for c in classes]

    if ground_truth:
        labels = ["[{}] {}".format("GT", label) for label in labels]
    elif scores is not None:
        assert len(classes) == len(scores)
        labels = ["[{:.2f}] {}".format(s, label) for s, label in zip(scores, labels)]
    return labels


class ImgVisualizer(Visualizer):
    def __init__(
        self, img_rgb: torch.Tensor, meta: Optional[SimpleNamespace] = None, **kwargs
    ) -> None:
        """
        See https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/visualizer.py
        for more details.
        Args:
            img_rgb: a tensor or numpy array of shape (H, W, C), where H and W correspond to
                the height and width of the image respectively. C is the number of
                color channels. The image is required to be in RGB format since that
                is a requirement of the Matplotlib library. The image is also expected
                to be in the range [0, 255].
            meta (MetadataCatalog): image metadata.
                See https://github.com/facebookresearch/detectron2/blob/81d5a87763bfc71a492b5be89b74179bd7492f6b/detectron2/data/catalog.py#L90
        """
        super().__init__(img_rgb, meta, **kwargs)

    def draw_text(
        self,
        text: str,
        position: List[int],
        *,
        font_size: Optional[int] = None,
        color: str = "w",
        horizontal_alignment: str = "center",
        vertical_alignment: str = "bottom",
        box_facecolor: str = "black",
        alpha: float = 0.5,
    ) -> None:
        """
        Draw text at the specified position.
        Args:
            text (str): the text to draw on image.
            position (list of 2 ints): the x,y coordinate to place the text.
            font_size (Optional[int]): font size of the text. If not provided, a font size
                proportional to the image width is calculated and used.
            color (str): color of the text. Refer to `matplotlib.colors` for full list
                of formats that are accepted.
            horizontal_alignment (str): see `matplotlib.text.Text`.
            vertical_alignment (str): see `matplotlib.text.Text`.
            box_facecolor (str): color of the box wrapped around the text. Refer to
                `matplotlib.colors` for full list of formats that are accepted.
            alpha (float): transparency level of the box.
        """
        if not font_size:
            font_size = self._default_font_size
        x, y = position
        self.output.ax.text(
            x,
            y,
            text,
            size=font_size * self.output.scale,
            family="monospace",
            bbox={
                "facecolor": box_facecolor,
                "alpha": alpha,
                "pad": 0.7,
                "edgecolor": "none",
            },
            verticalalignment=vertical_alignment,
            horizontalalignment=horizontal_alignment,
            color=color,
            zorder=10,
        )

    def draw_multiple_text(
        self,
        text_ls: List[str],
        box_coordinate: torch.Tensor,
        *,
        top_corner: bool = True,
        font_size: Optional[int] = None,
        color: str = "w",
        box_facecolors: Union[str, List[str]] = "black",
        alpha: float = 0.5,
    ) -> None:
        """
        Draw a list of text labels for some bounding box on the image.
        Args:
            text_ls (list of strings): a list of text labels.
            box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            top_corner (bool): If True, draw the text labels at (x_left, y_top) of the box.
                Else, draw labels at (x_left, y_bottom).
            font_size (Optional[int]): font size of the text. If not provided, a font size
                proportional to the image width is calculated and used.
            color (str): color of the text. Refer to `matplotlib.colors` for full list
                of formats that are accepted.
            box_facecolors (str or list of strs): colors of the boxes wrapped around
                the text. Refer to `matplotlib.colors` for the full list of formats
                that are accepted.
            alpha (float): transparency level of the box.
        """
        if not isinstance(box_facecolors, list):
            box_facecolors = [box_facecolors] * len(text_ls)
        assert len(box_facecolors) == len(
            text_ls
        ), "Number of colors provided is not equal to the number of text labels."
        if not font_size:
            font_size = self._default_font_size
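        # Each text row (the label plus its padding) occupies roughly 1.5x the
        # font size vertically; this row height decides how many labels fit.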
        text_box_width = font_size + font_size // 2
        # If the text labels do not all fit in the assigned location, we split
        # them and draw the remainder in another place.
        if top_corner:
            num_text_split = self._align_y_top(
                box_coordinate, len(text_ls), text_box_width
            )
            y_corner = 1
        else:
            num_text_split = len(text_ls) - self._align_y_bottom(
                box_coordinate, len(text_ls), text_box_width
            )
            y_corner = 3

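        # Sort labels (and their colors) in descending lexicographic order;
        # for "[score] name" formatted labels this orders them by score,
        # highest first.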
        text_color_sorted = sorted(
            zip(text_ls, box_facecolors), key=lambda x: x[0], reverse=True
        )
        if len(text_color_sorted) != 0:
            text_ls, box_facecolors = zip(*text_color_sorted)
        else:
            text_ls, box_facecolors = [], []
        text_ls, box_facecolors = list(text_ls), list(box_facecolors)
        self.draw_multiple_text_upward(
            text_ls[:num_text_split][::-1],
            box_coordinate,
            y_corner=y_corner,
            font_size=font_size,
            color=color,
            box_facecolors=box_facecolors[:num_text_split][::-1],
            alpha=alpha,
        )
        self.draw_multiple_text_downward(
            text_ls[num_text_split:],
            box_coordinate,
            y_corner=y_corner,
            font_size=font_size,
            color=color,
            box_facecolors=box_facecolors[num_text_split:],
            alpha=alpha,
        )

    def draw_multiple_text_upward(
        self,
        text_ls: List[str],
        box_coordinate: torch.Tensor,
        *,
        y_corner: int = 1,
        font_size: Optional[int] = None,
        color: str = "w",
        box_facecolors: Union[str, List[str]] = "black",
        alpha: float = 0.5,
    ) -> None:
        """
        Draw a list of text labels for some bounding box on the image in upward direction.
        The next text label will be on top of the previous one.
        Args:
            text_ls (list of strings): a list of text labels.
            box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            y_corner (int): Value of either 1 or 3. Indicates the index of the y-coordinate of
                the box to draw labels around.
            font_size (Optional[int]): font size of the text. If not provided, a font size
                proportional to the image width is calculated and used.
            color (str): color of the text. Refer to `matplotlib.colors` for full list
                of formats that are accepted.
            box_facecolors (str or list of strs): colors of the box wrapped around the
                text. Refer to `matplotlib.colors` for full list of formats that
                are accepted.
            alpha (float): transparency level of the box.
        """
        if not isinstance(box_facecolors, list):
            box_facecolors = [box_facecolors] * len(text_ls)
        assert len(box_facecolors) == len(
            text_ls
        ), "Number of colors provided is not equal to the number of text labels."

        assert y_corner in [1, 3], "Y_corner must be either 1 or 3"
        if not font_size:
            font_size = self._default_font_size

        x, horizontal_alignment = self._align_x_coordinate(box_coordinate)
        y = box_coordinate[y_corner].item()
        for i, text in enumerate(text_ls):
            self.draw_text(
                text,
                (x, y),
                font_size=font_size,
                color=color,
                horizontal_alignment=horizontal_alignment,
                vertical_alignment="bottom",
                box_facecolor=box_facecolors[i],
                alpha=alpha,
            )
            y -= font_size + font_size // 2

    def draw_multiple_text_downward(
        self,
        text_ls: List[str],
        box_coordinate: torch.Tensor,
        *,
        y_corner: int = 1,
        font_size: Optional[int] = None,
        color: str = "w",
        box_facecolors: Union[str, List[str]] = "black",
        alpha: float = 0.5,
    ) -> None:
        """
        Draw a list of text labels for some bounding box on the image in downward direction.
        The next text label will be below the previous one.
        Args:
            text_ls (list of strings): a list of text labels.
            box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            y_corner (int): Value of either 1 or 3. Indicates the index of the y-coordinate of
                the box to draw labels around.
            font_size (Optional[int]): font size of the text. If not provided, a font size
                proportional to the image width is calculated and used.
            color (str): color of the text. Refer to `matplotlib.colors` for full list
                of formats that are accepted.
            box_facecolors (str or list of strs): colors of the boxes wrapped around
                the text. Refer to `matplotlib.colors` for the full list of formats
                that are accepted.
            alpha (float): transparency level of the box.
        """
        if not isinstance(box_facecolors, list):
            box_facecolors = [box_facecolors] * len(text_ls)
        assert len(box_facecolors) == len(
            text_ls
        ), "Number of colors provided is not equal to the number of text labels."

        assert y_corner in [1, 3], "Y_corner must be either 1 or 3"
        if not font_size:
            font_size = self._default_font_size

        x, horizontal_alignment = self._align_x_coordinate(box_coordinate)
        y = box_coordinate[y_corner].item()
        for i, text in enumerate(text_ls):
            self.draw_text(
                text,
                (x, y),
                font_size=font_size,
                color=color,
                horizontal_alignment=horizontal_alignment,
                vertical_alignment="top",
                box_facecolor=box_facecolors[i],
                alpha=alpha,
            )
            y += font_size + font_size // 2

    def _align_x_coordinate(self, box_coordinate: torch.Tensor) -> Tuple[float, str]:
        """
        Choose an x-coordinate from the box to make sure the text label
        does not go out of frames. By default, the left x-coordinate is
        chosen and text is aligned left. If the box is too close to the
        right side of the image, then the right x-coordinate is chosen
        instead and the text is aligned right.
        Args:
            box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
        Returns:
            x_coordinate (float): the chosen x-coordinate.
            alignment (str): whether to align left or right.
        """
        # If the box's left edge is beyond 5/6 of the image width, we align the
        # text to the right of the box instead. The 5/6 threshold is a heuristic.
        if box_coordinate[0] > (self.output.width * 5) // 6:
            return box_coordinate[2], "right"

        return box_coordinate[0], "left"

    def _align_y_top(
        self, box_coordinate: torch.Tensor, num_text: int, textbox_width: float
    ) -> int:
        """
        Calculate the number of text labels to plot on top of the box
        without going out of frames.
        Args:
            box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            num_text (int): the number of text labels to plot.
            textbox_width (float): the vertical extent of the box wrapped around a
                text label.
        Returns:
            num_text_top (int): the number of labels that fit above the box.
        """
        dist_to_top = box_coordinate[1]
        num_text_top = dist_to_top // textbox_width

        if isinstance(num_text_top, torch.Tensor):
            num_text_top = int(num_text_top.item())

        return min(num_text, num_text_top)

    def _align_y_bottom(
        self, box_coordinate: torch.Tensor, num_text: int, textbox_width: float
    ) -> int:
        """
        Calculate the number of text labels to plot at the bottom of the box
        without going out of frames.
        Args:
            box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            num_text (int): the number of text labels to plot.
            textbox_width (float): the vertical extent of the box wrapped around a
                text label.
        Returns:
            num_text_bottom (int): the number of labels that fit below the box.
        """
        dist_to_bottom = self.output.height - box_coordinate[3]
        num_text_bottom = dist_to_bottom // textbox_width

        if isinstance(num_text_bottom, torch.Tensor):
            num_text_bottom = int(num_text_bottom.item())

        return min(num_text, num_text_bottom)


class VideoVisualizer:
    def __init__(
        self,
        num_classes: int,
        class_names: Dict[int, str],
        top_k: int = 1,
        colormap: str = "rainbow",
        thres: float = 0.7,
        lower_thres: float = 0.3,
        common_class_names: Optional[List[str]] = None,
        mode: str = "top-k",
    ) -> None:
        """
        Args:
            num_classes (int): total number of classes.
            class_names (dict): Dict mapping classID to name.
            top_k (int): number of top predicted classes to plot.
            colormap (str): the colormap to choose color for class labels from.
                See https://matplotlib.org/tutorials/colors/colormaps.html
            thres (float): threshold for picking predicted classes to visualize.
            lower_thres (float): if `common_class_names` is given, `lower_thres`
                is applied to uncommon classes and `thres` is applied to classes
                in `common_class_names`.
            common_class_names (Optional[list of str]): list of common class names
                to apply `thres` to. Class names not included in `common_class_names`
                will have `lower_thres` as their threshold. If None, all classes will
                have `thres` as their threshold. This is helpful for models trained
                on highly imbalanced datasets.
            mode (str): Supported modes are {"top-k", "thres"}.
                This is used for choosing predictions for visualization.
        """
        assert mode in ["top-k", "thres"], "Mode {} is not supported.".format(mode)
        self.mode = mode
        self.num_classes = num_classes
        self.class_names = class_names
        self.top_k = top_k
        self.thres = thres
        self.lower_thres = lower_thres

        if mode == "thres":
            self._get_thres_array(common_class_names=common_class_names)

        self.color_map = plt.get_cmap(colormap)

    def _get_color(self, class_id: int) -> List[float]:
        """
        Get color for a class id.
        Args:
            class_id (int): class id.
        """
        return self.color_map(class_id / self.num_classes)[:3]

    def draw_one_frame(
        self,
        frame: Union[torch.Tensor, np.ndarray],
        preds: Union[torch.Tensor, List[float]],
        bboxes: Optional[torch.Tensor] = None,
        alpha: float = 0.5,
        text_alpha: float = 0.7,
        ground_truth: bool = False,
    ) -> np.ndarray:
        """
        Draw labels and bouding boxes for one image. By default, predicted
        labels are drawn in the top left corner of the image or corresponding
        bounding boxes. For ground truth labels (setting True for ground_truth flag),
        labels will be drawn in the bottom left corner.
        Args:
            frame (array-like): a tensor or numpy array of shape (H, W, C),
            where H and W correspond to
                the height and width of the image respectively. C is the number of
                color channels. The image is required to be in RGB format since that
                is a requirement of the Matplotlib library. The image is also expected
                to be in the range [0, 255].
            preds (tensor or list): If ground_truth is False, provide a float tensor of
                shape (num_boxes, num_classes) that contains all of the confidence
                scores of the model. For recognition task, input shape can be (num_classes,).
                To plot true label (ground_truth is True), preds is a list contains int32
                of the shape (num_boxes, true_class_ids) or (true_class_ids,).
            bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates
                of the bounding boxes.
            alpha (Optional[float]): transparency level of the bounding boxes.
            text_alpha (Optional[float]): transparency level of the box wrapped around
                text labels.
            ground_truth (bool): whether the prodived bounding boxes are ground-truth.
        Returns:
            An image with bounding box annotations and corresponding bbox
            labels plotted on it.
        """
        if isinstance(preds, torch.Tensor):
            if preds.ndim == 1:
                preds = preds.unsqueeze(0)
            n_instances = preds.shape[0]
        elif isinstance(preds, list):
            n_instances = len(preds)
        else:
            raise TypeError(
                "Unsupported type of prediction input: {}".format(type(preds))
            )

        if ground_truth:
            top_scores, top_classes = [None] * n_instances, preds

        elif self.mode == "top-k":
            top_scores, top_classes = torch.topk(preds, k=self.top_k)
            top_scores, top_classes = top_scores.tolist(), top_classes.tolist()
        elif self.mode == "thres":
            top_scores, top_classes = [], []
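            # self.thres was converted to a per-class tensor by
            # _get_thres_array, so this comparison broadcasts elementwise:
            # each class is kept only if its score clears its own threshold.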
            for pred in preds:
                mask = pred >= self.thres
                top_scores.append(pred[mask].tolist())
                top_class = torch.squeeze(torch.nonzero(mask), dim=-1).tolist()
                top_classes.append(top_class)

        # Create text labels for the top predicted classes and their scores.
        text_labels = []
        for i in range(n_instances):
            text_labels.append(
                _create_text_labels(
                    top_classes[i],
                    top_scores[i],
                    self.class_names,
                    ground_truth=ground_truth,
                )
            )
        frame_visualizer = ImgVisualizer(frame, meta=None)
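        # Heuristic font size: grows with sqrt(H * W) (roughly the image's
        # linear scale) divided by 25, clamped to the range [5, 9].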
        font_size = min(max(np.sqrt(frame.shape[0] * frame.shape[1]) // 25, 5), 9)
        top_corner = not ground_truth
        if bboxes is not None:
            assert len(preds) == len(
                bboxes
            ), "Encounter {} predictions and {} bounding boxes".format(
                len(preds), len(bboxes)
            )
            for i, box in enumerate(bboxes):
                text = text_labels[i]
                pred_class = top_classes[i]
                colors = [self._get_color(pred) for pred in pred_class]

                box_color = "r" if ground_truth else "g"
                line_style = "--" if ground_truth else "-."
                frame_visualizer.draw_box(
                    box,
                    alpha=alpha,
                    edge_color=box_color,
                    line_style=line_style,
                )
                frame_visualizer.draw_multiple_text(
                    text,
                    box,
                    top_corner=top_corner,
                    font_size=font_size,
                    box_facecolors=colors,
                    alpha=text_alpha,
                )
        else:
            text = text_labels[0]
            pred_class = top_classes[0]
            colors = [self._get_color(pred) for pred in pred_class]
            frame_visualizer.draw_multiple_text(
                text,
                torch.Tensor([0, 5, frame.shape[1], frame.shape[0] - 5]),
                top_corner=top_corner,
                font_size=font_size,
                box_facecolors=colors,
                alpha=text_alpha,
            )

        return frame_visualizer.output.get_image()

    def draw_clip_range(
        self,
        frames: Union[torch.Tensor, np.ndarray],
        preds: Union[torch.Tensor, List[float]],
        bboxes: Optional[torch.Tensor] = None,
        text_alpha: float = 0.5,
        ground_truth: bool = False,
        keyframe_idx: Optional[int] = None,
        draw_range: Optional[List[int]] = None,
        repeat_frame: int = 1,
    ) -> List[np.ndarray]:
        """
        Draw predicted labels or ground truth classes on a clip, and draw
        bounding boxes on the clip if bboxes is provided. Boxes gradually
        fade in and out of the clip, centered around the clip's keyframe,
        within the provided `draw_range`.
        Args:
            frames (array-like): video data in the shape (T, H, W, C).
            preds (tensor): a tensor of shape (num_boxes, num_classes) that
                contains all of the confidence scores of the model. For a recognition
                task or for ground-truth labels, the input shape can be (num_classes,).
            bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates
                of the bounding boxes.
            text_alpha (float): transparency level of the box wrapped around text labels.
            ground_truth (bool): whether the provided bounding boxes are ground truth.
            keyframe_idx (Optional[int]): the index of the keyframe in the clip.
            draw_range (Optional[list[int]]): only draw frames in the range
                [start_idx, end_idx] (inclusive) of the clip. If None, draw on
                the entire clip.
            repeat_frame (int): repeat each frame in draw_range `repeat_frame`
                times for a slow-motion effect.
        Returns:
            A list of frames with bounding box annotations and corresponding
            bbox labels plotted on them.
        """
        if draw_range is None:
            draw_range = [0, len(frames) - 1]
        # `draw_range` is guaranteed to be set at this point.
        draw_range[0] = max(0, draw_range[0])
        left_frames = frames[: draw_range[0]]
        right_frames = frames[draw_range[1] + 1 :]

        draw_frames = frames[draw_range[0] : draw_range[1] + 1]
        if keyframe_idx is None:
            keyframe_idx = len(frames) // 2

        img_ls = (
            list(left_frames)
            + self.draw_clip(
                draw_frames,
                preds,
                bboxes=bboxes,
                text_alpha=text_alpha,
                ground_truth=ground_truth,
                keyframe_idx=keyframe_idx - draw_range[0],
                repeat_frame=repeat_frame,
            )
            + list(right_frames)
        )

        return img_ls

    def draw_clip(
        self,
        frames: Union[torch.Tensor, np.ndarray],
        preds: Union[torch.Tensor, List[float]],
        bboxes: Optional[torch.Tensor] = None,
        text_alpha: float = 0.5,
        ground_truth: bool = False,
        keyframe_idx: Optional[int] = None,
        repeat_frame: int = 1,
    ) -> List[np.ndarray]:
        """
        Draw predicted labels or ground truth classes on a clip, and draw bounding
        boxes on the clip if bboxes is provided. Boxes gradually fade in and out
        of the clip, centered around the clip's keyframe.
        Args:
            frames (array-like): video data in the shape (T, H, W, C).
            preds (tensor): a tensor of shape (num_boxes, num_classes) that contains
                all of the confidence scores of the model. For a recognition task or
                for ground-truth labels, the input shape can be (num_classes,).
            bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates
                of the bounding boxes.
            text_alpha (float): transparency level of the box wrapped around text labels.
            ground_truth (bool): whether the provided bounding boxes are ground truth.
            keyframe_idx (Optional[int]): the index of the keyframe in the clip.
            repeat_frame (int): repeat each frame `repeat_frame` times for a
                slow-motion effect.
        Returns:
            A list of frames with bounding box annotations and corresponding
            bbox labels plotted on them.
        """
        assert repeat_frame >= 1, "`repeat_frame` must be a positive integer."

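        # Duplicate each frame index `repeat_frame` times, e.g. indices
        # [0, 1, 2] with repeat_frame=2 become [0, 0, 1, 1, 2, 2].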
        repeated_seq = range(0, len(frames))
        repeated_seq = list(
            itertools.chain.from_iterable(
                itertools.repeat(x, repeat_frame) for x in repeated_seq
            )
        )

        frames, adjusted = self._adjust_frames_type(frames)
        if keyframe_idx is None:
            half_left = len(repeated_seq) // 2
            half_right = (len(repeated_seq) + 1) // 2
        else:
            mid = int((keyframe_idx / len(frames)) * len(repeated_seq))
            half_left = mid
            half_right = len(repeated_seq) - mid

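        # Box/label opacity ramps linearly from 0 up to 1 at the keyframe and
        # back down to 0, producing the fade-in/fade-out effect.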
        alpha_ls = np.concatenate(
            [
                np.linspace(0, 1, num=half_left),
                np.linspace(1, 0, num=half_right),
            ]
        )
        frames = frames[repeated_seq]
        img_ls = []
        for alpha, frame in zip(alpha_ls, frames):
            draw_img = self.draw_one_frame(
                frame,
                preds,
                bboxes,
                alpha=alpha,
                text_alpha=text_alpha,
                ground_truth=ground_truth,
            )
            if adjusted:
                draw_img = draw_img.astype("float32") / 255

            img_ls.append(draw_img)

        return img_ls

    def _adjust_frames_type(
        self, frames: torch.Tensor
    ) -> Tuple[np.ndarray, bool]:
        """
        Modify video data to have dtype uint8 and values in the range [0, 255].
        Float frames are assumed to be in the range [0, 1].
        Args:
            frames (array-like): 4D array of shape (T, H, W, C).
        Returns:
            frames (ndarray): frames with dtype uint8 in the range [0, 255].
            adjusted (bool): whether the original frames were adjusted.
        """
        assert (
            frames is not None and len(frames) != 0
        ), "Frames do not contain any values"
        frames = np.array(frames)
        assert frames.ndim == 4, "Frames must have 4 dimensions"
        adjusted = False
        if frames.dtype in [np.float32, np.float64]:
            frames *= 255
            frames = frames.astype(np.uint8)
            adjusted = True

        return frames, adjusted

    def _get_thres_array(self, common_class_names: Optional[List[str]] = None) -> None:
        """
        Compute a per-class thresholds array based on `self.thres` and `self.lower_thres`.
        Args:
            common_class_names (Optional[list of str]): a list of common class names.
        """
        common_class_ids = []
        if common_class_names is not None:
            common_classes = set(common_class_names)

            for key, name in self.class_names.items():
                if name in common_classes:
                    common_class_ids.append(key)
        else:
            common_class_ids = list(range(self.num_classes))

        thres_array = np.full(shape=(self.num_classes,), fill_value=self.lower_thres)
        thres_array[common_class_ids] = self.thres
        self.thres = torch.from_numpy(thres_array)
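

# ---------------------------------------------------------------------------
# Usage sketch (not part of the borrowed slowfast code): a minimal example of
# driving VideoVisualizer with random data. The class names, shapes, and
# values below are illustrative assumptions, not anything this module requires.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical 5-class label map.
    demo_class_names = {i: "class_{}".format(i) for i in range(5)}
    visualizer = VideoVisualizer(
        num_classes=5,
        class_names=demo_class_names,
        top_k=2,
        mode="top-k",
    )

    # A fake 8-frame RGB clip with float values in [0, 1]; float frames are
    # rescaled to uint8 internally by _adjust_frames_type.
    demo_frames = np.random.rand(8, 64, 64, 3).astype(np.float32)
    # Clip-level confidence scores, shape (num_classes,).
    demo_preds = torch.softmax(torch.rand(5), dim=0)

    annotated = visualizer.draw_clip_range(demo_frames, demo_preds)
    print(
        "Annotated {} frames; first frame shape: {}".format(
            len(annotated), annotated[0].shape
        )
    )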