Spaces:
Sleeping
Sleeping
# Copyright (c) OpenMMLab. All rights reserved. | |
from projects.easydeploy.model import ORTWrapper, TRTWrapper # isort:skip | |
import os | |
import random | |
from argparse import ArgumentParser | |
import cv2 | |
import mmcv | |
import numpy as np | |
import torch | |
from mmcv.transforms import Compose | |
from mmdet.utils import get_test_pipeline_cfg | |
from mmengine.config import Config, ConfigDict | |
from mmengine.utils import ProgressBar, path | |
from mmyolo.utils import register_all_modules | |
from mmyolo.utils.misc import get_file_list | |
def parse_args():
    """Read the demo's command-line options.

    Positional: ``img``, ``config``, ``checkpoint`` (in that order).
    Optional: ``--out-dir`` (default ``./output``), ``--device``
    (default ``cuda:0``) and the ``--show`` flag.

    Returns:
        argparse.Namespace: the parsed options.
    """
    cli = ArgumentParser()

    # Positional arguments, registered in the order they appear on the CLI.
    positional = (
        ('img', 'Image path, include image file, dir and URL.'),
        ('config', 'Config file'),
        ('checkpoint', 'Checkpoint file'),
    )
    for dest, help_text in positional:
        cli.add_argument(dest, help=help_text)

    # Optional switches.
    cli.add_argument(
        '--out-dir', default='./output', help='Path to output file')
    cli.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    cli.add_argument(
        '--show', action='store_true', help='Show the detection results')

    return cli.parse_args()
def preprocess(config):
    """Build a module that normalizes a single image tensor for the model.

    ``mean`` and ``std`` are read from
    ``config['model']['data_preprocessor']``. When absent they default to
    ``[0, 0, 0]`` and ``[1, 1, 1]``, i.e. identity normalization.

    The returned module's forward pass adds a leading batch dimension,
    casts to float32 and computes ``(x - mean) / std``. It never modifies
    its input. mean/std are reshaped to ``(1, 3, 1, 1)``, so ``x`` is
    expected to be channel-first with 3 channels (e.g. the ``inputs``
    tensor of the test pipeline — TODO confirm for other pipelines).

    Args:
        config: Mapping-like config (dict or mmengine ``Config``) that
            supports ``.get``.

    Returns:
        torch.nn.Module: the normalization module, in eval mode.
    """
    data_preprocess = config.get('model', {}).get('data_preprocessor', {})
    mean = data_preprocess.get('mean', [0., 0., 0.])
    std = data_preprocess.get('std', [1., 1., 1.])
    # Shape (1, 3, 1, 1) broadcasts over a (1, 3, H, W) batch.
    mean = torch.tensor(mean, dtype=torch.float32).reshape(1, 3, 1, 1)
    std = torch.tensor(std, dtype=torch.float32).reshape(1, 3, 1, 1)

    class PreProcess(torch.nn.Module):

        def __init__(self):
            super().__init__()
            # Buffers move together with the module on .to()/.cuda().
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)

        def forward(self, x):
            # ``x[None].float()`` is a view of the caller's tensor when it is
            # already float32, so use out-of-place ops: in-place ``-=``/``/=``
            # would overwrite the caller's data.
            x = x[None].float()
            return (x - self.mean.to(x.device)) / self.std.to(x.device)

    return PreProcess().eval()
def _build_model(checkpoint, device):
    """Wrap an exported ONNX model or TensorRT engine/plan for inference.

    Raises:
        NotImplementedError: if ``checkpoint`` has an unsupported suffix.
    """
    if checkpoint.endswith('.onnx'):
        model = ORTWrapper(checkpoint, device)
    elif checkpoint.endswith(('.engine', '.plan')):
        model = TRTWrapper(checkpoint, device)
    else:
        raise NotImplementedError(
            f'Unsupported checkpoint format: {checkpoint!r} '
            '(expected .onnx, .engine or .plan)')
    model.to(device)
    return model


def _rescale_boxes(bboxes, pad_param, scale_factor, h, w):
    """Map (N, 4) boxes from network-input coords back onto an h x w image.

    Removes the padding offset, divides by the resize scale, clamps x
    (even columns) to [0, w] and y (odd columns) to [0, h], then rounds to
    int. ``bboxes`` is modified in place before the rounded copy is made.
    """
    bboxes -= pad_param
    bboxes /= scale_factor
    bboxes[:, 0::2].clamp_(0, w)
    bboxes[:, 1::2].clamp_(0, h)
    return bboxes.round().int()


def _draw_detections(img, bboxes, scores, labels, colors, class_names):
    """Draw a rectangle and a ``cls:..._score:...`` caption per box on img.

    ``img`` is modified in place. ``colors`` is indexed by label id, and so
    is ``class_names`` when it is not None.
    """
    for bbox, score, label in zip(bboxes, scores, labels):
        bbox = bbox.tolist()
        color = colors[label]
        if class_names is not None:
            label_name = class_names[label]
            name = f'cls:{label_name}_score:{score:0.4f}'
        else:
            name = f'cls:{label}_score:{score:0.4f}'
        cv2.rectangle(img, bbox[:2], bbox[2:], color, 2)
        cv2.putText(
            img,
            name, (bbox[0], bbox[1] - 2),
            cv2.FONT_HERSHEY_SIMPLEX,
            2.0, [225, 255, 255],
            thickness=3)


def main():
    """Run an exported detector over every input image and show/save results.

    Parses the CLI and builds the ONNX Runtime / TensorRT wrapper. For each
    image it then:

    * runs the config's test pipeline plus normalization,
    * maps the predicted boxes back onto the original image,
    * draws them, and either shows the image or writes it under
      ``--out-dir``.
    """
    args = parse_args()

    # register all modules in mmdet into the registries
    register_all_modules()

    # One random color per class id (labels must be < 1000).
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(1000)]

    # build the model from an exported ONNX file or TensorRT engine
    model = _build_model(args.checkpoint, args.device)

    cfg = Config.fromfile(args.config)
    class_names = cfg.get('class_name')

    test_pipeline = get_test_pipeline_cfg(cfg)
    # Images are decoded below, so the first (loading) step is swapped for one
    # that takes the in-memory array passed as ``img``.
    test_pipeline[0] = ConfigDict({'type': 'mmdet.LoadImageFromNDArray'})
    test_pipeline = Compose(test_pipeline)

    pre_pipeline = preprocess(cfg)

    if not args.show:
        path.mkdir_or_exist(args.out_dir)

    # get file list
    files, source_type = get_file_list(args.img)

    # start detector inference
    progress_bar = ProgressBar(len(files))
    for i, file in enumerate(files):
        bgr = mmcv.imread(file)
        rgb = mmcv.imconvert(bgr, 'bgr', 'rgb')
        # NOTE(review): relies on the pipeline's output key order
        # (inputs first, then data samples) — confirm if the pipeline changes.
        data, samples = test_pipeline(dict(img=rgb, img_id=i)).values()

        pad_param = samples.get('pad_param',
                                np.array([0, 0, 0, 0], dtype=np.float32))
        h, w = samples.get('ori_shape', rgb.shape[:2])
        # Index 2 is used as the x offset and index 0 as the y offset,
        # repeated for (x1, y1, x2, y2).
        pad_param = torch.asarray(
            [pad_param[2], pad_param[0], pad_param[2], pad_param[0]],
            device=args.device)
        scale_factor = samples.get('scale_factor', [1., 1])
        # tuple(...) * 2 repeats (sx, sy) -> (sx, sy, sx, sy). A bare
        # ``ndarray * 2`` would double the values instead.
        scale_factor = torch.asarray(
            tuple(scale_factor) * 2, device=args.device)

        with torch.no_grad():
            data = pre_pipeline(data).to(args.device)
            result = model(data)

        if source_type['is_dir']:
            # Flatten sub-directories into one file name (any path separator).
            filename = os.path.relpath(file, args.img)
            filename = filename.replace(os.sep, '_').replace('/', '_')
        else:
            filename = os.path.basename(file)
        out_file = None if args.show else os.path.join(args.out_dir, filename)

        # Get candidate predict info by num_dets; index 0 is the single image
        # of the batch created by the preprocessing module.
        num_dets, bboxes, scores, labels = result
        scores = scores[0, :num_dets]
        bboxes = bboxes[0, :num_dets]
        labels = labels[0, :num_dets]
        bboxes = _rescale_boxes(bboxes, pad_param, scale_factor, h, w)

        _draw_detections(bgr, bboxes, scores, labels, colors, class_names)

        if args.show:
            mmcv.imshow(bgr, 'result', 0)
        else:
            mmcv.imwrite(bgr, out_file)
        progress_bar.update()
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported.
if __name__ == '__main__':
    main()