# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import mmcv
import numpy as np
import torch
from mmcv.image import tensor2imgs
from mmcv.parallel import DataContainer
from mmdet.core import encode_mask_results
from .utils import tensor2grayimgs
def retrieve_img_tensor_and_meta(data):
    """Retrieve img_tensor, img_metas and img_norm_cfg from one batch.

    Args:
        data (dict): One batch data from data_loader. ``data['img']`` may be
            a Tensor, a list (of Tensors or DataContainers), or a
            DataContainer, depending on the test pipeline configuration.

    Returns:
        tuple: Returns (img_tensor, img_metas, img_norm_cfg).

            - | img_tensor (Tensor): Input image tensor with shape
                :math:`(N, C, H, W)`.
            - | img_metas (list[dict]): The metadata of images.
            - | img_norm_cfg (dict): Config for image normalization.

    Raises:
        TypeError: If ``data['img']`` has an unsupported type.
        KeyError: If a required meta key is missing from ``img_metas``.
    """
    if isinstance(data['img'], torch.Tensor):
        # for textrecog with batch_size > 1
        # and not use 'DefaultFormatBundle' in pipeline
        img_tensor = data['img']
        img_metas = data['img_metas'].data[0]
    elif isinstance(data['img'], list):
        if isinstance(data['img'][0], torch.Tensor):
            # for textrecog with aug_test and batch_size = 1
            img_tensor = data['img'][0]
        elif isinstance(data['img'][0], DataContainer):
            # for textdet with 'MultiScaleFlipAug'
            # and 'DefaultFormatBundle' in pipeline
            img_tensor = data['img'][0].data[0]
        else:
            # Fail loudly instead of hitting UnboundLocalError below.
            raise TypeError('Unsupported type of data["img"][0]: '
                            f'{type(data["img"][0])}')
        img_metas = data['img_metas'][0].data[0]
    elif isinstance(data['img'], DataContainer):
        # for textrecog with 'DefaultFormatBundle' in pipeline
        img_tensor = data['img'].data[0]
        img_metas = data['img_metas'].data[0]
    else:
        # Fail loudly instead of hitting UnboundLocalError below.
        raise TypeError(f'Unsupported type of data["img"]: '
                        f'{type(data["img"])}')

    must_keys = ['img_norm_cfg', 'ori_filename', 'img_shape', 'ori_shape']
    for key in must_keys:
        if key not in img_metas[0]:
            raise KeyError(
                f'Please add {key} to the "meta_keys" in the pipeline')

    img_norm_cfg = img_metas[0]['img_norm_cfg']
    # Normalization stats may be expressed in [0, 1]; rescale to [0, 255]
    # so tensor2imgs/tensor2grayimgs denormalize correctly.
    if max(img_norm_cfg['mean']) <= 1:
        img_norm_cfg['mean'] = [255 * x for x in img_norm_cfg['mean']]
        img_norm_cfg['std'] = [255 * x for x in img_norm_cfg['std']]

    return img_tensor, img_metas, img_norm_cfg
def single_gpu_test(model,
                    data_loader,
                    show=False,
                    out_dir=None,
                    is_kie=False,
                    show_score_thr=0.3):
    """Test a model with a single GPU, optionally visualizing results.

    Args:
        model (nn.Module): Model wrapped by ``MMDataParallel``;
            visualization goes through ``model.module.show_result``.
        data_loader (DataLoader): PyTorch data loader yielding batches.
        show (bool): Whether to display visualizations on screen.
        out_dir (str, optional): Directory to save visualization images.
            ``None`` disables saving.
        is_kie (bool): Whether the model is a KIE model, which uses a
            different visualization path (draws ground-truth boxes).
        show_score_thr (float): Score threshold passed to
            ``show_result`` on the non-KIE path.

    Returns:
        list: Test results for the whole dataset, one entry per sample.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for data in data_loader:
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)
        batch_size = len(result)
        if show or out_dir:
            if is_kie:
                img_tensor = data['img'].data[0]
                # KIE visualization assumes exactly one image per batch.
                if img_tensor.shape[0] != 1:
                    # Fixed: the two string literals previously concatenated
                    # without a separating space ("is" + "currently").
                    raise KeyError('Visualizing KIE outputs in batches is '
                                   'currently not supported.')
                gt_bboxes = data['gt_bboxes'].data[0]
                img_metas = data['img_metas'].data[0]
                must_keys = ['img_norm_cfg', 'ori_filename', 'img_shape']
                for key in must_keys:
                    if key not in img_metas[0]:
                        raise KeyError(
                            f'Please add {key} to the "meta_keys" in config.')
                # for no visual model
                if np.prod(img_tensor.shape) == 0:
                    # No image tensor was produced; load images from disk,
                    # falling back to a blank image on failure (best-effort
                    # visualization should not abort the whole test run).
                    imgs = []
                    for img_meta in img_metas:
                        try:
                            img = mmcv.imread(img_meta['filename'])
                        except Exception as e:
                            print(f'Load image with error: {e}, '
                                  'use empty image instead.')
                            img = np.ones(
                                img_meta['img_shape'], dtype=np.uint8)
                        imgs.append(img)
                else:
                    imgs = tensor2imgs(img_tensor,
                                       **img_metas[0]['img_norm_cfg'])
                for i, img in enumerate(imgs):
                    # Crop away the padding added by the test pipeline.
                    h, w, _ = img_metas[i]['img_shape']
                    img_show = img[:h, :w, :]
                    if out_dir:
                        out_file = osp.join(out_dir,
                                            img_metas[i]['ori_filename'])
                    else:
                        out_file = None
                    model.module.show_result(
                        img_show,
                        result[i],
                        gt_bboxes[i],
                        show=show,
                        out_file=out_file)
            else:
                img_tensor, img_metas, img_norm_cfg = \
                    retrieve_img_tensor_and_meta(data)
                # Grayscale (1-channel) and color images denormalize via
                # different helpers.
                if img_tensor.size(1) == 1:
                    imgs = tensor2grayimgs(img_tensor, **img_norm_cfg)
                else:
                    imgs = tensor2imgs(img_tensor, **img_norm_cfg)
                assert len(imgs) == len(img_metas)

                for j, (img, img_meta) in enumerate(zip(imgs, img_metas)):
                    img_shape, ori_shape = img_meta['img_shape'], img_meta[
                        'ori_shape']
                    # Drop padding, then resize back to the original shape
                    # so drawn results align with the source image.
                    img_show = img[:img_shape[0], :img_shape[1]]
                    img_show = mmcv.imresize(img_show,
                                             (ori_shape[1], ori_shape[0]))
                    if out_dir:
                        out_file = osp.join(out_dir, img_meta['ori_filename'])
                    else:
                        out_file = None
                    model.module.show_result(
                        img_show,
                        result[j],
                        show=show,
                        out_file=out_file,
                        score_thr=show_score_thr)

        # encode mask results
        if isinstance(result[0], tuple):
            result = [(bbox_results, encode_mask_results(mask_results))
                      for bbox_results, mask_results in result]
        results.extend(result)

        for _ in range(batch_size):
            prog_bar.update()
    return results