# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import mmcv
import numpy as np
import torch
from mmcv.image import tensor2imgs
from mmcv.parallel import DataContainer
from mmdet.core import encode_mask_results

from .utils import tensor2grayimgs


def retrieve_img_tensor_and_meta(data):
    """Retrieve img_tensor, img_metas and img_norm_cfg.

    Args:
        data (dict): One batch data from data_loader.

    Returns:
        tuple: Returns (img_tensor, img_metas, img_norm_cfg).

            - | img_tensor (Tensor): Input image tensor with shape
                :math:`(N, C, H, W)`.
            - | img_metas (list[dict]): The metadata of images.
            - | img_norm_cfg (dict): Config for image normalization.
    """
    if isinstance(data['img'], torch.Tensor):
        # for textrecog with batch_size > 1
        # and not use 'DefaultFormatBundle' in pipeline
        img_tensor = data['img']
        img_metas = data['img_metas'].data[0]
    elif isinstance(data['img'], list):
        if isinstance(data['img'][0], torch.Tensor):
            # for textrecog with aug_test and batch_size = 1
            img_tensor = data['img'][0]
        elif isinstance(data['img'][0], DataContainer):
            # for textdet with 'MultiScaleFlipAug'
            # and 'DefaultFormatBundle' in pipeline
            img_tensor = data['img'][0].data[0]
        img_metas = data['img_metas'][0].data[0]
    elif isinstance(data['img'], DataContainer):
        # for textrecog with 'DefaultFormatBundle' in pipeline
        img_tensor = data['img'].data[0]
        img_metas = data['img_metas'].data[0]

    must_keys = ['img_norm_cfg', 'ori_filename', 'img_shape', 'ori_shape']
    for key in must_keys:
        if key not in img_metas[0]:
            raise KeyError(
                f'Please add {key} to the "meta_keys" in the pipeline')

    img_norm_cfg = img_metas[0]['img_norm_cfg']
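    # Some configs store mean/std in the [0, 1] range; tensor2imgs /
    # tensor2grayimgs expect values on the 0-255 scale, so rescale if needed.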
    if max(img_norm_cfg['mean']) <= 1:
        img_norm_cfg['mean'] = [255 * x for x in img_norm_cfg['mean']]
        img_norm_cfg['std'] = [255 * x for x in img_norm_cfg['std']]

    return img_tensor, img_metas, img_norm_cfg


def single_gpu_test(model,
                    data_loader,
                    show=False,
                    out_dir=None,
                    is_kie=False,
                    show_score_thr=0.3):
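    """Test a model with a single GPU and optionally visualize the results.

    Args:
        model (nn.Module): Model to test, typically wrapped by
            ``MMDataParallel`` (results are drawn via
            ``model.module.show_result``).
        data_loader (DataLoader): PyTorch data loader over the test dataset.
        show (bool): Whether to show visualizations on screen.
        out_dir (str, optional): Directory to save visualizations to. If both
            ``show`` and ``out_dir`` are falsy, visualization is skipped.
        is_kie (bool): Whether the model is a KIE model, whose outputs are
            visualized together with the ground-truth boxes.
        show_score_thr (float): Score threshold passed to ``show_result``
            for non-KIE models.

    Returns:
        list: Test results of the whole dataset.

    Example (illustrative sketch, not part of the original module; assumes
    the model and data loader were built in the usual MMOCR way):

        >>> from mmcv.parallel import MMDataParallel
        >>> model = MMDataParallel(model, device_ids=[0])
        >>> results = single_gpu_test(model, data_loader, out_dir='vis/')
    """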
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for data in data_loader:
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)

        batch_size = len(result)
        if show or out_dir:
            if is_kie:
                img_tensor = data['img'].data[0]
                if img_tensor.shape[0] != 1:
                    raise KeyError('Visualizing KIE outputs in batches is '
                                   'currently not supported.')
                gt_bboxes = data['gt_bboxes'].data[0]
                img_metas = data['img_metas'].data[0]
                must_keys = ['img_norm_cfg', 'ori_filename', 'img_shape']
                for key in must_keys:
                    if key not in img_metas[0]:
                        raise KeyError(
                            f'Please add {key} to the "meta_keys" in config.')
                # for no visual model
                if np.prod(img_tensor.shape) == 0:
                    imgs = []
                    for img_meta in img_metas:
                        try:
                            img = mmcv.imread(img_meta['filename'])
                        except Exception as e:
                            print(f'Failed to load image ({e}), '
                                  'using an empty image instead.')
                            img = np.ones(
                                img_meta['img_shape'], dtype=np.uint8)
                        imgs.append(img)
                else:
                    imgs = tensor2imgs(img_tensor,
                                       **img_metas[0]['img_norm_cfg'])

                for i, img in enumerate(imgs):
                    h, w, _ = img_metas[i]['img_shape']
                    img_show = img[:h, :w, :]

                    if out_dir:
                        out_file = osp.join(out_dir,
                                            img_metas[i]['ori_filename'])
                    else:
                        out_file = None

                    model.module.show_result(
                        img_show,
                        result[i],
                        gt_bboxes[i],
                        show=show,
                        out_file=out_file)
            else:
                img_tensor, img_metas, img_norm_cfg = \
                    retrieve_img_tensor_and_meta(data)
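                # Single-channel inputs (grayscale recognition images) need
                # the grayscale de-normalizer; otherwise use tensor2imgs.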
                if img_tensor.size(1) == 1:
                    imgs = tensor2grayimgs(img_tensor, **img_norm_cfg)
                else:
                    imgs = tensor2imgs(img_tensor, **img_norm_cfg)
                assert len(imgs) == len(img_metas)

                for j, (img, img_meta) in enumerate(zip(imgs, img_metas)):
                    img_shape, ori_shape = img_meta['img_shape'], img_meta[
                        'ori_shape']
                    img_show = img[:img_shape[0], :img_shape[1]]
                    img_show = mmcv.imresize(img_show,
                                             (ori_shape[1], ori_shape[0]))

                    if out_dir:
                        out_file = osp.join(out_dir, img_meta['ori_filename'])
                    else:
                        out_file = None

                    model.module.show_result(
                        img_show,
                        result[j],
                        show=show,
                        out_file=out_file,
                        score_thr=show_score_thr)

        # encode mask results
        if isinstance(result[0], tuple):
            result = [(bbox_results, encode_mask_results(mask_results))
                      for bbox_results, mask_results in result]
        results.extend(result)

        for _ in range(batch_size):
            prog_bar.update()
    return results