# Source: ov-seg/open_vocab_seg/utils/predictor.py
# Author: liangfeng — commit 583456e ("add ovseg")
# Page residue from the original listing: "raw history blame", "No virus", 5.09 kB
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import numpy as np
import torch
from detectron2.data import MetadataCatalog
from detectron2.engine.defaults import DefaultPredictor
from detectron2.utils.visualizer import ColorMode, Visualizer
class OVSegPredictor(DefaultPredictor):
    """DefaultPredictor variant that passes a per-call vocabulary to the model.

    Pre-processing matches detectron2's ``DefaultPredictor``; the only
    addition is that ``__call__`` takes ``class_names`` and places it in the
    single input dict handed to the model.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

    def __call__(self, original_image, class_names):
        """Run the model on one image with a caller-chosen set of class names.

        Args:
            original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
            class_names: candidate category names; forwarded unchanged to the
                model under the "class_names" key of its input dict.

        Returns:
            predictions (dict): the model's output for this one image
                (element 0 of the model's batched output).
        """
        # Inference only: disable autograd tracking for the whole forward pass.
        with torch.no_grad():
            # The incoming image is BGR; flip channels only if the model wants RGB.
            img = original_image[:, :, ::-1] if self.input_format == "RGB" else original_image
            # Record the original size before the resize augmentation is applied.
            h, w = img.shape[:2]
            resized = self.aug.get_transform(img).apply_image(img)
            # HWC array -> CHW float32 tensor.
            chw = torch.as_tensor(resized.astype("float32").transpose(2, 0, 1))
            batch = [{"image": chw, "height": h, "width": w, "class_names": class_names}]
            return self.model(batch)[0]
class OVSegVisualizer(Visualizer):
    """Visualizer that labels segments with a caller-supplied vocabulary.

    When ``class_names`` is given, ``draw_sem_seg`` uses it as segment labels
    instead of ``metadata.stuff_classes``.
    """

    def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE, class_names=None):
        super().__init__(img_rgb, metadata, scale, instance_mode)
        # Label vocabulary; None falls back to metadata.stuff_classes when drawing.
        self.class_names = class_names

    def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8):
        """
        Draw semantic segmentation predictions/labels.

        Segments are drawn from largest to smallest area. Labels outside
        ``[0, len(class_names))`` (e.g. a 255 "blank" sentinel) are skipped.

        Args:
            sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
                Each value is the integer label of the pixel. Tensors may live
                on any device.
            area_threshold (int): segments with less than `area_threshold` are not drawn.
            alpha (float): the larger it is, the more opaque the segmentations are.

        Returns:
            output (VisImage): image object with visualizations.
        """
        if isinstance(sem_seg, torch.Tensor):
            # .cpu() is a no-op for CPU tensors and prevents the TypeError that
            # .numpy() raises for tensors on a GPU.
            sem_seg = sem_seg.cpu().numpy()
        labels, areas = np.unique(sem_seg, return_counts=True)
        # Largest segments first.
        labels = labels[np.argsort(-areas)]
        class_names = self.class_names if self.class_names is not None else self.metadata.stuff_classes
        num_classes = len(class_names)
        for label in labels:
            # Only labels that name a class are drawn; negative ids would
            # otherwise index class_names / stuff_colors from the end.
            if not 0 <= label < num_classes:
                continue
            try:
                mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
            except (AttributeError, IndexError):
                # Metadata has no (or too few) colors; presumably
                # draw_binary_mask picks a color itself when given None.
                mask_color = None
            binary_mask = (sem_seg == label).astype(np.uint8)
            self.draw_binary_mask(
                binary_mask,
                color=mask_color,
                edge_color=(1.0, 1.0, 240.0 / 255),  # off-white outline
                text=class_names[label],
                alpha=alpha,
                area_threshold=area_threshold,
            )
        return self.output
class VisualizationDemo(object):
    """Run OVSegPredictor on an image and render its semantic segmentation.

    Only the synchronous (``parallel=False``) mode is implemented.
    """

    def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False):
        """
        Args:
            cfg (CfgNode): detectron2 config. Its first ``DATASETS.TEST`` entry,
                if present, selects the metadata used for drawing; otherwise
                the name "__unused" is looked up.
            instance_mode (ColorMode): color mode forwarded to the visualizer.
            parallel (bool): whether to run the model in different processes from visualization.
                Not implemented: passing True raises NotImplementedError.
        """
        self.metadata = MetadataCatalog.get(
            cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
        )
        self.cpu_device = torch.device("cpu")
        self.instance_mode = instance_mode
        self.parallel = parallel
        if parallel:
            raise NotImplementedError
        self.predictor = OVSegPredictor(cfg)

    def run_on_image(self, image, class_names):
        """
        Args:
            image (np.ndarray): an image of shape (H, W, C) (in BGR order).
                This is the format used by OpenCV.
            class_names: category names forwarded to the predictor and used
                by the visualizer as segment labels.

        Returns:
            predictions (dict): the output of the model.
            vis_output (VisImage): the visualized image output.

        Raises:
            NotImplementedError: if the predictions have no "sem_seg" entry.
        """
        predictions = self.predictor(image, class_names)
        # Convert image from OpenCV BGR format to Matplotlib RGB format.
        image = image[:, :, ::-1]
        visualizer = OVSegVisualizer(image, self.metadata, instance_mode=self.instance_mode, class_names=class_names)
        if "sem_seg" not in predictions:
            raise NotImplementedError
        # NOTE(review): scores are presumably (num_classes, H, W) — confirm.
        scores = predictions["sem_seg"]
        # Build the label map on the scores' own device so the blank mask and
        # the indexed tensor agree. Indexing a CPU tensor with a CUDA mask
        # raises a RuntimeError.
        label_map = scores.argmax(dim=0)
        # Pixels whose class-0 score is exactly 0 get the sentinel 255, which
        # the visualizer skips as long as there are fewer than 256 class names.
        # NOTE(review): presumably the model zeroes all scores in regions it
        # ignores — confirm.
        label_map[scores[0] == 0] = 255
        # np.int was removed in NumPy 1.24; use an explicit 64-bit dtype.
        label_map = label_map.cpu().numpy().astype(np.int64)
        vis_output = visualizer.draw_sem_seg(label_map)
        return predictions, vis_output