| | |
| |
|
| | from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union |
| |
|
| | import cv2 |
| | import numpy as np |
| | import torch |
| |
|
| | if TYPE_CHECKING: |
| | from matplotlib.backends.backend_agg import FigureCanvasAgg |
| |
|
| |
|
def tensor2ndarray(value: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
    """Return ``value`` as a NumPy array when it is a ``torch.Tensor``.

    Tensors are detached from the autograd graph and moved to the CPU before
    conversion. Any other input is returned unchanged.

    Args:
        value (np.ndarray, torch.Tensor): The value to convert.

    Returns:
        np.ndarray: The converted array, or ``value`` itself when it was not
        a tensor.
    """
    if not isinstance(value, torch.Tensor):
        return value
    return value.detach().cpu().numpy()
| |
|
| |
|
def value2list(value: Any, valid_type: Union[Type, Tuple[Type, ...]],
               expand_dim: int) -> List[Any]:
    """Broadcast a single ``valid_type`` value into a list of copies.

    Args:
        value (Any): The value to (possibly) broadcast.
        valid_type (Union[Type, Tuple[Type, ...]]): Type(s) that trigger the
            broadcast, as accepted by ``isinstance``.
        expand_dim (int): Length of the produced list.

    Returns:
        List[Any]: ``[value] * expand_dim`` when ``value`` is an instance of
        ``valid_type``; otherwise ``value`` is returned unchanged.
    """
    return [value] * expand_dim if isinstance(value, valid_type) else value
| |
|
| |
|
def check_type(name: str, value: Any,
               valid_type: Union[Type, Tuple[Type, ...]]) -> None:
    """Check whether the type of ``value`` is in ``valid_type``.

    Args:
        name (str): Name of the value, used in the error message.
        value (Any): The value to check.
        valid_type (Type, Tuple[Type, ...]): Expected type(s), as accepted by
            ``isinstance``.

    Raises:
        TypeError: If ``value`` is not an instance of ``valid_type``.
    """
    if not isinstance(value, valid_type):
        raise TypeError(f'`{name}` should be {valid_type} '
                        f'but got {type(value)}')
| |
|
| |
|
def check_length(name: str, value: Any, valid_length: int) -> None:
    """Ensure a list ``value`` holds at least ``valid_length`` items.

    Non-list values are accepted without any check.

    Args:
        name (str): Name of the value, used in the error message.
        value (Any): The value to check.
        valid_length (int): Minimum allowed length for a list ``value``.

    Raises:
        AssertionError: If ``value`` is a list shorter than ``valid_length``.
    """
    if not isinstance(value, list) or len(value) >= valid_length:
        return
    raise AssertionError(
        f'The length of {name} must equal with or '
        f'greater than {valid_length}, but got {len(value)}')
| |
|
| |
|
def check_type_and_length(name: str, value: Any,
                          valid_type: Union[Type, Tuple[Type, ...]],
                          valid_length: int) -> None:
    """Check the type of ``value`` and, for lists, its minimum length.

    First checks that ``value`` is an instance of ``valid_type``. Then, if
    ``value`` is a list, checks that its length is equal to or greater than
    ``valid_length``.

    Args:
        name (str): Name of the value, used in error messages.
        value (Any): The value to check.
        valid_type (Type, Tuple[Type, ...]): Expected type(s).
        valid_length (int): Minimum length required when ``value`` is a list.

    Raises:
        TypeError: Raised by ``check_type`` when the type is wrong.
        AssertionError: Raised by ``check_length`` when a list is too short.
    """
    check_type(name, value, valid_type)
    check_length(name, value, valid_length)
| |
|
| |
|
def color_val_matplotlib(
    colors: Union[str, tuple, List[Union[str, tuple]]]
) -> Union[str, tuple, List[Union[str, tuple]]]:
    """Convert RGB colors in the 0-255 range to matplotlib's 0-1 range.

    Strings are passed through unchanged. RGB tuples with channels in
    ``[0, 255]`` are divided by 255. Lists are converted element-wise, and
    nested lists are converted recursively.

    Args:
        colors (Union[str, tuple, List[Union[str, tuple]]]): Color input(s)
            in RGB order.

    Returns:
        Union[str, tuple, List[Union[str, tuple]]]: The input with every RGB
        tuple replaced by a tuple of 3 floats in ``[0, 1]``.

    Raises:
        AssertionError: If a tuple does not have exactly 3 channels or a
            channel lies outside ``[0, 255]``.
        TypeError: If an input is neither a str, a tuple nor a list.
    """
    if isinstance(colors, str):
        return colors
    if isinstance(colors, tuple):
        # Explicit raises instead of ``assert`` keep this validation active
        # under ``python -O``. AssertionError is kept so existing callers
        # that catch it keep working.
        if len(colors) != 3:
            raise AssertionError(
                f'A color tuple must have 3 channels, but got {len(colors)}')
        for channel in colors:
            if not 0 <= channel <= 255:
                raise AssertionError(
                    f'Color channels must be in [0, 255], but got {channel}')
        return tuple(channel / 255 for channel in colors)
    if isinstance(colors, list):
        return [color_val_matplotlib(color) for color in colors]
    raise TypeError(f'Invalid type for color: {type(colors)}')
| |
|
| |
|
def color_str2rgb(color: str) -> tuple:
    """Convert a matplotlib color string to an RGB tuple of ints in [0, 255].

    Any alpha channel in the input is silently dropped.

    Args:
        color (str): Matplotlib color, e.g. ``'red'`` or ``'#1f77b4'``.

    Returns:
        tuple: RGB color as three ints in ``[0, 255]``.
    """
    # Import the submodule explicitly. ``import matplotlib`` alone does not
    # guarantee that the ``matplotlib.colors`` attribute is populated.
    from matplotlib.colors import to_rgb

    # ``to_rgb`` yields floats in [0, 1]. Round rather than truncate, so that
    # a product such as 127.99999999999999 (float error in ``c * 255``) maps
    # to 128 instead of 127, and 8-bit inputs like hex colors round-trip.
    return tuple(int(round(c * 255)) for c in to_rgb(color))
| |
|
| |
|
def convert_overlay_heatmap(feat_map: Union[np.ndarray, torch.Tensor],
                            img: Optional[np.ndarray] = None,
                            alpha: float = 0.5) -> np.ndarray:
    """Convert ``feat_map`` to a JET heatmap, optionally blended onto ``img``.

    Args:
        feat_map (np.ndarray, torch.Tensor): The feature map to convert, of
            shape (H, W) or (C, H, W) with C in {1, 3}. H is the image height
            and W is the image width.
        img (np.ndarray, optional): The original image in RGB order. When
            given, it is blended with the heatmap via ``cv2.addWeighted``, so
            it must match the heatmap's size and dtype. Defaults to None.
        alpha (float): Weight of the heatmap in the blend. The image gets
            weight ``1 - alpha``. Defaults to 0.5.

    Returns:
        np.ndarray: The RGB heatmap, blended with ``img`` if one was given.

    Raises:
        AssertionError: If ``feat_map`` has an unsupported shape.
    """
    # Explicit raise instead of ``assert`` so the check survives
    # ``python -O``. The exception type is kept for existing callers.
    if not (feat_map.ndim == 2 or
            (feat_map.ndim == 3 and feat_map.shape[0] in [1, 3])):
        raise AssertionError(
            'feat_map must have shape (H, W) or (C, H, W) with C in (1, 3), '
            f'but got {tuple(feat_map.shape)}')
    if isinstance(feat_map, torch.Tensor):
        feat_map = feat_map.detach().cpu().numpy()

    if feat_map.ndim == 3:
        # OpenCV works on channel-last (H, W, C) arrays.
        feat_map = feat_map.transpose(1, 2, 0)

    # Min-max scale into [0, 255] so the colormap spans the full value range.
    norm_img = np.zeros(feat_map.shape)
    norm_img = cv2.normalize(feat_map, norm_img, 0, 255, cv2.NORM_MINMAX)
    norm_img = np.asarray(norm_img, dtype=np.uint8)
    heat_img = cv2.applyColorMap(norm_img, cv2.COLORMAP_JET)
    # OpenCV colormaps are BGR; convert to the RGB convention used here.
    heat_img = cv2.cvtColor(heat_img, cv2.COLOR_BGR2RGB)
    if img is not None:
        heat_img = cv2.addWeighted(img, 1 - alpha, heat_img, alpha, 0)
    return heat_img
| |
|
| |
|
def wait_continue(figure, timeout: float = 0, continue_key: str = ' ') -> int:
    """Display ``figure`` and block until the user responds.

    Adapted from matplotlib's blocking-input helper:
    https://github.com/matplotlib/matplotlib/blob/v3.5.x/lib/matplotlib/_blocking_input.py

    Args:
        figure: The matplotlib figure to show and to listen on for key-press
            and close events.
        timeout (float): If positive, continue after ``timeout`` seconds.
            The value is passed unchanged to ``canvas.start_event_loop``.
            Defaults to 0.
        continue_key (str): The key that lets the caller continue. Defaults
            to the space key. Other keys are ignored and waiting resumes.

    Returns:
        int: 0 if the wait timed out, ``continue_key`` was pressed, or the
        backend is inline (nothing is waited on); 1 if the user closed the
        figure.
    """
    import matplotlib.pyplot as plt
    from matplotlib.backend_bases import CloseEvent

    # Inline backends (e.g. notebooks) cannot block on user input.
    if 'inline' in plt.get_backend():
        return 0

    if figure.canvas.manager:
        figure.show()

    while True:
        received = None

        def on_event(ev):
            nonlocal received
            # Once a close event has arrived, a later key press must not
            # overwrite it.
            if not isinstance(received, CloseEvent):
                received = ev
            figure.canvas.stop_event_loop()

        connection_ids = []
        for event_name in ('key_press_event', 'close_event'):
            connection_ids.append(
                figure.canvas.mpl_connect(event_name, on_event))

        try:
            figure.canvas.start_event_loop(timeout)
        finally:
            # Always detach the handlers, even if the event loop raised.
            for cid in connection_ids:
                figure.canvas.mpl_disconnect(cid)

        if isinstance(received, CloseEvent):
            return 1
        if received is None or received.key == continue_key:
            return 0
        # Some other key was pressed: wait again (with a fresh timeout).
| |
|
| |
|
def img_from_canvas(canvas: 'FigureCanvasAgg') -> np.ndarray:
    """Read the pixels of ``canvas`` and return them as an RGB image.

    Args:
        canvas (FigureCanvasAgg): The canvas to get the image from.

    Returns:
        np.ndarray: A newly allocated ``uint8`` array of shape
        ``(height, width, 3)`` in RGB order. The alpha channel of the RGBA
        buffer returned by ``print_to_buffer`` is discarded.
    """
    raw, size = canvas.print_to_buffer()
    width, height = size
    rgba = np.frombuffer(raw, dtype='uint8').reshape(height, width, 4)
    # Drop alpha and copy, so the result is writable and does not alias the
    # canvas buffer.
    return rgba[..., :3].copy()
| |
|