# Copyright (c) Facebook, Inc. and its affiliates. # Copyright (c) Meta Platforms, Inc. All Rights Reserved import numpy as np import torch from detectron2.data import MetadataCatalog from detectron2.engine.defaults import DefaultPredictor from detectron2.utils.visualizer import ColorMode, Visualizer class OVSegPredictor(DefaultPredictor): def __init__(self, cfg): super().__init__(cfg) def __call__(self, original_image, class_names): """ Args: original_image (np.ndarray): an image of shape (H, W, C) (in BGR order). Returns: predictions (dict): the output of the model for one image only. See :doc:`/tutorials/models` for details about the format. """ with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258 # Apply pre-processing to image. if self.input_format == "RGB": # whether the model expects BGR inputs or RGB original_image = original_image[:, :, ::-1] height, width = original_image.shape[:2] image = self.aug.get_transform(original_image).apply_image(original_image) image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)) inputs = {"image": image, "height": height, "width": width, "class_names": class_names} predictions = self.model([inputs])[0] return predictions class OVSegVisualizer(Visualizer): def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE, class_names=None): super().__init__(img_rgb, metadata, scale, instance_mode) self.class_names = class_names def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8): """ Draw semantic segmentation predictions/labels. Args: sem_seg (Tensor or ndarray): the segmentation of shape (H, W). Each value is the integer label of the pixel. area_threshold (int): segments with less than `area_threshold` are not drawn. alpha (float): the larger it is, the more opaque the segmentations are. Returns: output (VisImage): image object with visualizations. """ if isinstance(sem_seg, torch.Tensor): sem_seg = sem_seg.numpy() labels, areas = np.unique(sem_seg, return_counts=True) sorted_idxs = np.argsort(-areas).tolist() labels = labels[sorted_idxs] class_names = self.class_names if self.class_names is not None else self.metadata.stuff_classes for label in filter(lambda l: l < len(class_names), labels): try: mask_color = [x / 255 for x in self.metadata.stuff_colors[label]] except (AttributeError, IndexError): mask_color = None binary_mask = (sem_seg == label).astype(np.uint8) text = class_names[label] self.draw_binary_mask( binary_mask, color=mask_color, edge_color=(1.0, 1.0, 240.0 / 255), text=text, alpha=alpha, area_threshold=area_threshold, ) return self.output class VisualizationDemo(object): def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False): """ Args: cfg (CfgNode): instance_mode (ColorMode): parallel (bool): whether to run the model in different processes from visualization. Useful since the visualization logic can be slow. """ self.metadata = MetadataCatalog.get( cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused" ) self.cpu_device = torch.device("cpu") self.instance_mode = instance_mode self.parallel = parallel if parallel: raise NotImplementedError else: self.predictor = OVSegPredictor(cfg) def run_on_image(self, image, class_names): """ Args: image (np.ndarray): an image of shape (H, W, C) (in BGR order). This is the format used by OpenCV. Returns: predictions (dict): the output of the model. vis_output (VisImage): the visualized image output. """ predictions = self.predictor(image, class_names) # Convert image from OpenCV BGR format to Matplotlib RGB format. image = image[:, :, ::-1] visualizer = OVSegVisualizer(image, self.metadata, instance_mode=self.instance_mode, class_names=class_names) if "sem_seg" in predictions: r = predictions["sem_seg"] blank_area = (r[0] == 0) pred_mask = r.argmax(dim=0).to('cpu') pred_mask[blank_area] = 255 pred_mask = np.array(pred_mask, dtype=np.int) vis_output = visualizer.draw_sem_seg( pred_mask ) else: raise NotImplementedError return predictions, vis_output