|
|
|
|
|
import torch |
|
|
|
from detectron2.engine.defaults import DefaultPredictor |
|
from detectron2.utils.visualizer import ColorMode, Visualizer |
|
|
|
|
|
class BatchDefaultPredictor(DefaultPredictor):
    """Predictor that sends several images through the model in a single call."""

    def __call__(self, original_images):
        """
        Run the model once on a batch of images.

        Args:
            original_images (np.ndarray or sequence of np.ndarray): either a
                single array of shape (N, H, W, C) or a sequence of (H, W, C)
                arrays. Images in a sequence may have different sizes; each
                one is passed to the model with its own height and width.
                The channel order is presumably BGR, as in the single-image
                predictor -- TODO confirm against callers.

        Returns:
            predictions (dict):
                the first element of the list returned by the model, i.e. the
                output for one image only (presumably the first image of the
                batch). See :doc:`/tutorials/models` for details about the
                format.
        """
        # NOTE(review): no channel reordering based on the configured input
        # format is applied here; confirm that callers already supply images
        # in the channel order the model expects.
        with torch.no_grad():
            batch_inputs = []
            for original_image in original_images:
                # Read the size per image rather than from the stacked batch,
                # so a plain list of differently sized images also works.
                height, width = original_image.shape[:2]
                image = self.aug.get_transform(original_image).apply_image(original_image)
                # HWC -> CHW float32 tensor.
                image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
                batch_inputs.append({"image": image, "height": height, "width": width})

            # NOTE(review): only element 0 of the model output is returned, so
            # results for the remaining images are discarded. Kept as-is
            # because callers may rely on this return shape.
            predictions = self.model(batch_inputs)[0]
            return predictions
|
|
|
class Visualizer_GRiT(Visualizer):
    """Visualizer that labels each predicted box with its object description."""

    def __init__(self, image, instance_mode=None):
        """
        Args:
            image: image to draw on; forwarded unchanged to ``Visualizer``.
            instance_mode: color mode forwarded to ``Visualizer``.
        """
        super().__init__(image, instance_mode=instance_mode)

    def draw_instance_predictions(self, predictions):
        """
        Draw predicted boxes and their description labels on the image.

        Masks and keypoints are never overlaid; ``pred_masks`` is only used to
        build the grayscale background in ``ColorMode.IMAGE_BW``.

        Args:
            predictions: an object exposing ``has(name)`` and, optionally, the
                fields "pred_boxes", "pred_classes",
                "pred_object_descriptions" and "pred_masks". Any missing
                field is skipped instead of raising.

        Returns:
            ``self.output`` after the instances have been drawn.
        """
        boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
        classes = predictions.pred_classes.tolist() if predictions.has("pred_classes") else None
        # Guarded like the other fields so boxes can still be drawn (without
        # labels) when no descriptions were predicted.
        object_description = (
            predictions.pred_object_descriptions.data
            if predictions.has("pred_object_descriptions")
            else None
        )

        if (
            self._instance_mode == ColorMode.SEGMENTATION
            and self.metadata.get("thing_colors")
            # Per-class colors need class ids; without them fall back to the
            # default coloring instead of iterating over None.
            and classes is not None
        ):
            colors = [
                self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes
            ]
            alpha = 0.8
        else:
            colors = None
            alpha = 0.5

        if self._instance_mode == ColorMode.IMAGE_BW:
            # Replace the canvas with a grayscale version; the union of all
            # predicted masks (if any) is handed to the grayscale helper.
            self.output.reset_image(
                self._create_grayscale_image(
                    (predictions.pred_masks.any(dim=0) > 0).numpy()
                    if predictions.has("pred_masks")
                    else None
                )
            )
            alpha = 0.3

        self.overlay_instances(
            masks=None,
            boxes=boxes,
            labels=object_description,
            keypoints=None,
            assigned_colors=colors,
            alpha=alpha,
        )
        return self.output
|
|
|
|
|
class VisualizationDemo:
    """Bundle a ``DefaultPredictor`` with ``Visualizer_GRiT`` drawing."""

    def __init__(self, cfg, instance_mode=ColorMode.IMAGE):
        """
        Args:
            cfg: configuration object handed to ``DefaultPredictor``.
            instance_mode: color mode later passed to ``Visualizer_GRiT``.
        """
        self.cpu_device = torch.device("cpu")
        self.instance_mode = instance_mode

        self.predictor = DefaultPredictor(cfg)

    def run_on_image(self, image):
        """
        Predict on a single image and render the predicted instances.

        Args:
            image: array indexed as ``[row, col, channel]``; presumably in BGR
                channel order -- TODO confirm against callers.

        Returns:
            tuple: ``(predictions, vis_output)`` where ``predictions`` is the
            raw predictor output and ``vis_output`` is whatever
            ``draw_instance_predictions`` returns.
        """
        predictions = self.predictor(image)

        # The visualizer receives the image with its channel axis reversed.
        flipped = image[:, :, ::-1]
        drawer = Visualizer_GRiT(flipped, instance_mode=self.instance_mode)

        # Move the instances to the CPU device before drawing.
        cpu_instances = predictions["instances"].to(self.cpu_device)
        rendered = drawer.draw_instance_predictions(predictions=cpu_instances)

        return predictions, rendered