File size: 5,360 Bytes
9bf4bd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Optional, Sequence, Union

import mmcv
import mmengine.fileio as fileio
from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmengine.visualization import Visualizer

from mmocr.registry import HOOKS
from mmocr.structures import TextDetDataSample, TextRecogDataSample


# TODO Files with the same name will be overwritten for multi datasets
@HOOKS.register_module()
class VisualizationHook(Hook):
    """Detection Visualization Hook. Used to visualize validation and testing
    process prediction results.

    Args:
        enable (bool): Whether to enable this hook. Defaults to False.
        interval (int): The interval of visualization. Only used during
            validation; testing visualizes every iteration. Defaults to 50.
        score_thr (float): The threshold to visualize the bboxes
            and masks. It's only useful for text detection. Defaults to 0.3.
        show (bool): Whether to display the drawn image. Defaults to False.
        draw_pred (bool): Forwarded to the visualizer's ``add_datasample``
            as ``draw_pred``. Defaults to False.
        draw_gt (bool): Forwarded to the visualizer's ``add_datasample``
            as ``draw_gt``. Defaults to False.
        wait_time (float): The interval of show in seconds. Defaults
            to 0.
        backend_args (dict, optional): Instantiates the corresponding file
            backend. It may contain `backend` key to specify the file
            backend. If it contains, the file backend corresponding to this
            value will be used and initialized with the remaining values,
            otherwise the corresponding file backend will be selected
            based on the prefix of the file path. Defaults to None.
    """

    def __init__(
        self,
        enable: bool = False,
        interval: int = 50,
        score_thr: float = 0.3,
        show: bool = False,
        draw_pred: bool = False,
        draw_gt: bool = False,
        wait_time: float = 0.,
        backend_args: Optional[dict] = None,
    ) -> None:
        self._visualizer: Visualizer = Visualizer.get_current_instance()
        self.interval = interval
        self.score_thr = score_thr
        self.show = show
        self.draw_pred = draw_pred
        self.draw_gt = draw_gt
        self.wait_time = wait_time
        self.backend_args = backend_args
        self.enable = enable

    def _visualize_outputs(
            self, outputs: Sequence[Union[TextDetDataSample,
                                          TextRecogDataSample]],
            step: int) -> None:
        """Load the image of every data sample and pass it to the visualizer.

        Args:
            outputs (Sequence[:obj:`TextDetDataSample` or
                :obj:`TextRecogDataSample`]): Outputs from model. Each one
                must expose an ``img_path`` attribute.
            step (int): Value forwarded to ``add_datasample`` as ``step``.
        """
        for output in outputs:
            img_path = output.img_path
            img_bytes = fileio.get(img_path, backend_args=self.backend_args)
            img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
            # The image file name without its extension is used as the name
            # of the drawn sample, so samples sharing a basename collide.
            self._visualizer.add_datasample(
                osp.splitext(osp.basename(img_path))[0],
                img,
                data_sample=output,
                draw_gt=self.draw_gt,
                draw_pred=self.draw_pred,
                show=self.show,
                wait_time=self.wait_time,
                pred_score_thr=self.score_thr,
                step=step)

    # TODO after MultiDatasetWrapper, rewrites this function and try to merge
    # with after_val_iter and after_test_iter
    def after_val_iter(self, runner: Runner, batch_idx: int,
                       data_batch: Sequence[dict],
                       outputs: Sequence[Union[TextDetDataSample,
                                               TextRecogDataSample]]) -> None:
        """Run after every ``self.interval`` validation iterations.

        Args:
            runner (:obj:`Runner`): The runner of the validation process.
            batch_idx (int): The index of the current batch in the val loop.
            data_batch (Sequence[dict]): Data from dataloader.
            outputs (Sequence[:obj:`TextDetDataSample` or
                :obj:`TextRecogDataSample`]): Outputs from model.
        """
        # TODO: data_batch does not include annotation information
        if not self.enable:
            return

        # There is no guarantee that the same batch of images
        # is visualized for each evaluation.
        total_curr_iter = runner.iter + batch_idx

        # Every sample of the batch is visualized, but only on iterations
        # that are a multiple of ``self.interval``.
        if total_curr_iter % self.interval == 0:
            self._visualize_outputs(outputs, step=total_curr_iter)

    def after_test_iter(self, runner: Runner, batch_idx: int,
                        data_batch: Sequence[dict],
                        outputs: Sequence[Union[TextDetDataSample,
                                                TextRecogDataSample]]) -> None:
        """Run after every testing iterations.

        Args:
            runner (:obj:`Runner`): The runner of the testing process.
            batch_idx (int): The index of the current batch in the test loop.
            data_batch (Sequence[dict]): Data from dataloader.
            outputs (Sequence[:obj:`TextDetDataSample` or
                :obj:`TextRecogDataSample`]): Outputs from model.
        """
        if not self.enable:
            return

        # Unlike validation, ``self.interval`` is not consulted here: every
        # test batch is visualized, using the batch index as the step.
        self._visualize_outputs(outputs, step=batch_idx)