rawalkhirodkar's picture
Add initial commit
28c256d
raw
history blame
3.76 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Tuple, Union
import mmcv
import numpy as np
from mmengine.utils import is_str
def palette_val(palette: List[tuple]) -> List[tuple]:
    """Rescale 0-255 color tuples to the 0-1 range used by matplotlib.

    Args:
        palette (List[tuple]): Color tuples whose channels are expressed
            on a 0-255 scale.

    Returns:
        List[tuple[float]]: One tuple per input color, with every channel
        divided by 255.
    """
    return [tuple(channel / 255 for channel in rgb) for rgb in palette]
def get_palette(palette: Union[List[tuple], str, tuple],
                num_classes: int) -> List[Tuple[int]]:
    """Get palette from various inputs.

    Args:
        palette (list[tuple] | str | tuple | None): palette inputs.

            - ``list``: returned as-is.
            - ``tuple``: a single color, repeated ``num_classes`` times.
            - ``'random'`` or ``None``: ``num_classes`` pseudo-random colors
              drawn from a generator seeded with 42, so the result is the
              same on every call.
            - ``'coco'``, ``'citys'``, ``'voc'``: the ``METAINFO['palette']``
              of the matching ``mmdet`` dataset class. For ``'coco'`` the
              ``CocoPanopticDataset`` palette is used when the
              ``CocoDataset`` one has fewer than ``num_classes`` entries.
            - any other ``str``: resolved with ``mmcv.color_val``, reversed
              channel-wise and repeated ``num_classes`` times.
        num_classes (int): the number of classes.

    Returns:
        list[tuple[int]]: A list of color tuples.

    Raises:
        AssertionError: If ``num_classes`` is not an int, or the resolved
            palette has fewer than ``num_classes`` entries.
        TypeError: If ``palette`` is of an unsupported type.
    """
    assert isinstance(num_classes, int)

    if isinstance(palette, list):
        dataset_palette = palette
    elif isinstance(palette, tuple):
        dataset_palette = [palette] * num_classes
    elif palette == 'random' or palette is None:
        # A private, fixed-seed generator yields the same colors as seeding
        # the global NumPy RNG with 42, without touching the global state.
        rng = np.random.RandomState(42)
        colors = rng.randint(0, 256, size=(num_classes, 3))
        # ``tolist`` converts NumPy integer scalars to plain Python ints so
        # the result matches the declared ``tuple[int]`` element type.
        dataset_palette = [tuple(c) for c in colors.tolist()]
    elif palette == 'coco':
        from mmdet.datasets import CocoDataset, CocoPanopticDataset
        dataset_palette = CocoDataset.METAINFO['palette']
        if len(dataset_palette) < num_classes:
            dataset_palette = CocoPanopticDataset.METAINFO['palette']
    elif palette == 'citys':
        from mmdet.datasets import CityscapesDataset
        dataset_palette = CityscapesDataset.METAINFO['palette']
    elif palette == 'voc':
        from mmdet.datasets import VOCDataset
        dataset_palette = VOCDataset.METAINFO['palette']
    elif is_str(palette):
        # NOTE(review): the [::-1] presumably flips BGR to RGB - confirm
        # against mmcv.color_val's documented channel order.
        dataset_palette = [mmcv.color_val(palette)[::-1]] * num_classes
    else:
        raise TypeError(f'Invalid type for palette: {type(palette)}')

    assert len(dataset_palette) >= num_classes, \
        'The length of palette should not be less than `num_classes`.'
    return dataset_palette
def _get_adaptive_scales(areas: np.ndarray,
                         min_area: int = 800,
                         max_area: int = 30000) -> np.ndarray:
    """Get adaptive scales according to areas.

    The scale range is [0.5, 1.0]. When the area is less than
    ``min_area``, the scale is 0.5 while the area is larger than
    ``max_area``, the scale is 1.0. In between, the scale grows linearly
    with the area at a slope of ``1 / (max_area - min_area)`` and is then
    clipped, so it already saturates at 1.0 halfway between ``min_area``
    and ``max_area``.

    Args:
        areas (ndarray): The areas of bboxes or masks with the
            shape of (n, ).
        min_area (int): Lower bound areas for adaptive scales.
            Defaults to 800.
        max_area (int): Upper bound areas for adaptive scales.
            Defaults to 30000.

    Returns:
        ndarray: The adaptive scales with the shape of (n, ).
    """
    # True division is required: floor division would map every area below
    # ``max_area`` to 0.5, turning the scale into a two-valued step function.
    scales = 0.5 + (areas - min_area) / (max_area - min_area)
    return np.clip(scales, 0.5, 1.0)
def jitter_color(color: tuple) -> tuple:
    """Perturb a color at random so that instances sharing a class can be
    told apart.

    Three values are drawn with ``np.random.rand``, normalized to unit
    length, re-centered by subtracting 0.5 and scaled by ``0.5 * 255``
    before being added to ``color``. The sum is clipped to [0, 255] and
    cast to ``np.uint8``.

    Args:
        color (tuple): The RGB color tuple. Each value is between [0, 255].

    Returns:
        tuple: The jittered color tuple.
    """
    direction = np.random.rand(3)
    unit = direction / np.linalg.norm(direction)
    offset = (unit - 0.5) * 0.5 * 255
    shifted = np.clip(offset + color, 0, 255)
    return tuple(shifted.astype(np.uint8))