Spaces:
Running
Running
import copy
import os.path as osp
from typing import Optional, Sequence

import cv2
import mmcv
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from mmcv.transforms import Compose
from mmdet.apis import inference_detector
from mmdet.engine import DetVisualizationHook
from mmdet.registry import HOOKS
from mmdet.structures import DetDataSample
from mmengine.dist import master_only
from mmengine.fileio import FileClient

from utils.io_utils import find_all_imgs, square_pad_resize, imglist2grid
def inference_detector(
    model: nn.Module,
    imgs,
    test_pipeline
):
    """Run inference with ``model`` on one image or a batch of images.

    Note: this deliberately shadows ``mmdet.apis.inference_detector`` so
    that an explicit ``test_pipeline`` (from the runner config) can be
    supplied instead of the one stored on the model.

    Args:
        model (nn.Module): A detector exposing ``test_step``; assumed to
            already be on the right device and in eval mode.
        imgs: A single image or a list/tuple of images. Each item is
            either a loaded ``np.ndarray`` or a path-like string.
        test_pipeline: The test pipeline config (a sequence of per-stage
            config dicts). It is NOT mutated by this call.

    Returns:
        A single ``DetDataSample`` when ``imgs`` was a single image, a
        list of them when it was a list/tuple (possibly empty).
    """
    if isinstance(imgs, (list, tuple)):
        is_batch = True
    else:
        imgs = [imgs]
        is_batch = False

    if not imgs:
        return []

    # BUGFIX: the previous ``test_pipeline.copy()`` was a *shallow* list
    # copy, so assigning ``test_pipeline[0].type`` below mutated the
    # caller's shared per-stage config dict. Deep-copy instead so the
    # caller's pipeline config is left untouched.
    test_pipeline = copy.deepcopy(test_pipeline)
    if isinstance(imgs[0], np.ndarray):
        # Calling this method across libraries will result
        # in module unregistered error if not prefixed with mmdet.
        test_pipeline[0].type = 'mmdet.LoadImageFromNDArray'
    test_pipeline = Compose(test_pipeline)

    result_list = []
    for img in imgs:
        # Prepare the pipeline input dict for either in-memory arrays or
        # image paths.
        if isinstance(img, np.ndarray):
            # TODO: remove img_id.
            data_ = dict(img=img, img_id=0)
        else:
            # TODO: remove img_id.
            data_ = dict(img_path=img, img_id=0)
        # Build the data pipeline, then wrap inputs/data_samples into
        # one-element lists because ``test_step`` expects a batch.
        data_ = test_pipeline(data_)
        data_['inputs'] = [data_['inputs']]
        data_['data_samples'] = [data_['data_samples']]
        # Forward the model without tracking gradients.
        with torch.no_grad():
            result_list.append(model.test_step(data_)[0])

    # Preserve the caller's calling convention: scalar in -> scalar out.
    return result_list if is_batch else result_list[0]
class InstanceSegVisualizationHook(DetVisualizationHook):
    """Visualization hook that renders predictions on a fixed probe set.

    Unlike the parent ``DetVisualizationHook`` (which visualizes images
    drawn from the validation dataloader, so the pictures change between
    runs), this hook loads a fixed set of sample images once at
    construction time and runs them through the model before every
    validation, so the same images are visualized at every evaluation.

    NOTE(review): ``HOOKS`` is imported at module level but this class is
    not decorated with ``@HOOKS.register_module()`` — confirm whether it
    is registered elsewhere or meant to be instantiated directly.

    Args:
        visualize_samples (str): Directory holding the probe images. When
            it does not exist, no samples are loaded and the hook is a
            no-op at visualization time.
        read_rgb (bool): Read probe images in RGB order instead of BGR.
        draw (bool): Forwarded to ``DetVisualizationHook``.
        interval (int): Forwarded to ``DetVisualizationHook``.
        score_thr (float): Prediction score threshold for drawing.
        show (bool): Whether to display images interactively.
        wait_time (float): Display wait time, forwarded to the parent.
        test_out_dir (str, optional): Forwarded to the parent.
        file_client_args (dict, optional): File backend config; ``None``
            means ``dict(backend='disk')``.
    """

    def __init__(self, visualize_samples: str = '',
                 read_rgb: bool = False,
                 draw: bool = False,
                 interval: int = 50,
                 score_thr: float = 0.3,
                 show: bool = False,
                 wait_time: float = 0.,
                 test_out_dir: Optional[str] = None,
                 file_client_args: Optional[dict] = None):
        # Avoid the shared-mutable-default-argument pitfall; ``None``
        # stands in for the previous ``dict(backend='disk')`` default.
        if file_client_args is None:
            file_client_args = dict(backend='disk')
        super().__init__(draw, interval, score_thr, show, wait_time,
                         test_out_dir, file_client_args)
        # Always set channel_order (previously it was only assigned when
        # the samples directory existed, leaving the attribute undefined
        # otherwise).
        self.channel_order = 'rgb' if read_rgb else 'bgr'
        self.vis_samples = []
        if osp.exists(visualize_samples):
            for imgp in find_all_imgs(visualize_samples, abs_path=True):
                img = mmcv.imread(imgp, channel_order=self.channel_order)
                # Pad to square and resize so every probe image shares
                # the same 640x640 shape for batching/grid layout.
                img, _, _, _ = square_pad_resize(img, 640)
                self.vis_samples.append(img)

    def before_val(self, runner) -> None:
        """Visualize the probe samples once before validation starts."""
        # Skip the forward pass entirely when no probe images were
        # loaded — there is nothing to draw.
        if self.vis_samples:
            self._visualize_data(runner.iter, runner)
        return super().before_val(runner)

    def _visualize_data(self, total_curr_iter, runner):
        """Run inference on the cached probe images and log a grid.

        Args:
            total_curr_iter (int): Global step the logged image is
                attached to.
            runner: The mmengine runner; provides ``model`` and
                ``cfg.test_pipeline``.
        """
        tgt_size = 384
        runner.model.eval()
        outputs = inference_detector(runner.model, self.vis_samples,
                                     test_pipeline=runner.cfg.test_pipeline)
        vis_results = []
        for img, output in zip(self.vis_samples, outputs):
            vis_img = self.add_datasample(
                'val_img',
                img,
                data_sample=output,
                show=self.show,
                wait_time=self.wait_time,
                pred_score_thr=self.score_thr,
                draw_gt=False,
                step=total_curr_iter)
            vis_results.append(
                cv2.resize(vis_img, (tgt_size, tgt_size),
                           interpolation=cv2.INTER_AREA))
        drawn_img = imglist2grid(vis_results, tgt_size)
        # ``imglist2grid`` returning None means there was nothing to
        # assemble; bail out rather than log an empty image.
        if drawn_img is None:
            return
        # NOTE(review): assumes the grid is BGR at this point and the
        # visualizer backend wants RGB — confirm against channel_order.
        drawn_img = cv2.cvtColor(drawn_img, cv2.COLOR_BGR2RGB)
        visualizer = self._visualizer
        visualizer.set_image(drawn_img)
        visualizer.add_image('val_img', drawn_img, total_curr_iter)

    def add_datasample(
            self,
            name: str,
            image: np.ndarray,
            data_sample: Optional['DetDataSample'] = None,
            draw_gt: bool = True,
            draw_pred: bool = True,
            show: bool = False,
            wait_time: float = 0,
            # TODO: Supported in mmengine's Viusalizer.
            out_file: Optional[str] = None,
            pred_score_thr: float = 0.3,
            step: int = 0) -> np.ndarray:
        """Draw GT and/or predicted instances onto ``image``.

        Unlike the visualizer's own ``add_datasample``, this RETURNS the
        drawn array (GT and prediction panels concatenated horizontally
        when both are drawn) instead of logging it, so the caller can
        post-process it (resize, grid, etc.).

        Args:
            name (str): Unused identifier kept for API compatibility.
            image (np.ndarray): Image to draw on; clipped to uint8.
            data_sample (DetDataSample, optional): Sample carrying
                ``gt_instances`` / ``pred_instances`` / panoptic fields.
            draw_gt (bool): Draw ground-truth annotations if present.
            draw_pred (bool): Draw predictions if present.
            show, wait_time, out_file, step: Accepted for signature
                compatibility; not used for drawing here.
            pred_score_thr (float): Predictions below this score are
                filtered out before drawing.

        Returns:
            np.ndarray: The drawn image (original image when nothing
            was drawn).
        """
        image = image.clip(0, 255).astype(np.uint8)
        visualizer = self._visualizer
        classes = visualizer.dataset_meta.get('classes', None)
        palette = visualizer.dataset_meta.get('palette', None)
        gt_img_data = None
        pred_img_data = None
        if data_sample is not None:
            # Move tensors to CPU once before any drawing.
            data_sample = data_sample.cpu()
        if draw_gt and data_sample is not None:
            gt_img_data = image
            if 'gt_instances' in data_sample:
                gt_img_data = visualizer._draw_instances(image,
                                                         data_sample.gt_instances,
                                                         classes, palette)
            if 'gt_panoptic_seg' in data_sample:
                assert classes is not None, 'class information is ' \
                                            'not provided when ' \
                                            'visualizing panoptic ' \
                                            'segmentation results.'
                # NOTE(review): the pred branch below calls ``.numpy()``
                # on the panoptic seg but this GT branch does not —
                # confirm whether that asymmetry is intentional.
                gt_img_data = visualizer._draw_panoptic_seg(
                    gt_img_data, data_sample.gt_panoptic_seg, classes)
        if draw_pred and data_sample is not None:
            pred_img_data = image
            if 'pred_instances' in data_sample:
                # Filter low-confidence predictions before drawing.
                pred_instances = data_sample.pred_instances
                pred_instances = pred_instances[
                    pred_instances.scores > pred_score_thr]
                pred_img_data = visualizer._draw_instances(image, pred_instances,
                                                           classes, palette)
            if 'pred_panoptic_seg' in data_sample:
                assert classes is not None, 'class information is ' \
                                            'not provided when ' \
                                            'visualizing panoptic ' \
                                            'segmentation results.'
                pred_img_data = visualizer._draw_panoptic_seg(
                    pred_img_data, data_sample.pred_panoptic_seg.numpy(),
                    classes)
        if gt_img_data is not None and pred_img_data is not None:
            # Side-by-side: GT on the left, predictions on the right.
            drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
        elif gt_img_data is not None:
            drawn_img = gt_img_data
        elif pred_img_data is not None:
            drawn_img = pred_img_data
        else:
            # Display the original image directly if nothing is drawn.
            drawn_img = image
        return drawn_img