File size: 4,968 Bytes
186701e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Copyright (c) OpenMMLab. All rights reserved.
from projects.easydeploy.model import ORTWrapper, TRTWrapper  # isort:skip
import os
import random
from argparse import ArgumentParser

import cv2
import mmcv
import numpy as np
import torch
from mmcv.transforms import Compose
from mmdet.utils import get_test_pipeline_cfg
from mmengine.config import Config, ConfigDict
from mmengine.utils import ProgressBar, path

from mmyolo.utils import register_all_modules
from mmyolo.utils.misc import get_file_list


def parse_args():
    """Define the demo's command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``img``,
        ``config``, ``checkpoint``, ``out_dir``, ``device`` and ``show``.
    """
    cli = ArgumentParser()

    # Positional arguments, registered in the order they appear on the
    # command line.
    positional = (
        ('img', 'Image path, include image file, dir and URL.'),
        ('config', 'Config file'),
        ('checkpoint', 'Checkpoint file'),
    )
    for dest, help_text in positional:
        cli.add_argument(dest, help=help_text)

    cli.add_argument(
        '--out-dir', default='./output', help='Path to output file')
    cli.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    cli.add_argument(
        '--show', action='store_true', help='Show the detection results')
    return cli.parse_args()


def preprocess(config):
    """Build a module that normalizes one image tensor before inference.

    Mean/std are read from ``config['model']['data_preprocessor']`` and
    default to ``[0, 0, 0]`` / ``[1, 1, 1]`` (i.e. no normalization) when
    the keys are absent.

    Args:
        config: A dict-like config (``dict`` or ``mmengine.Config``).

    Returns:
        torch.nn.Module: A module in eval mode whose ``forward(x)`` prepends
        a batch dimension, casts to float32 and returns
        ``(x - mean) / std``. The input tensor is never modified.
    """
    data_preprocess = config.get('model', {}).get('data_preprocessor', {})
    # Reshaped to (1, 3, 1, 1) so they broadcast over a batched image;
    # presumably the input is channel-first (3, H, W) — confirm with caller.
    mean = torch.tensor(
        data_preprocess.get('mean', [0., 0., 0.]),
        dtype=torch.float32).reshape(1, 3, 1, 1)
    std = torch.tensor(
        data_preprocess.get('std', [1., 1., 1.]),
        dtype=torch.float32).reshape(1, 3, 1, 1)

    class PreProcess(torch.nn.Module):

        def forward(self, x):
            x = x[None].float()
            # Out-of-place on purpose: if ``x`` is already float32,
            # ``.float()`` returns a view of the caller's tensor, and
            # in-place ``-=`` / ``/=`` would silently corrupt its data.
            return (x - mean.to(x.device)) / std.to(x.device)

    return PreProcess().eval()


def _output_filename(file, img_root, is_dir):
    """Return the file name under which the visualized result is saved.

    For directory inputs the path relative to ``img_root`` is flattened
    into one name (path separators become ``_``) so every result lands
    directly in the output directory.
    """
    if is_dir:
        # os.sep rather than a hard-coded '/': on Windows relpath returns
        # backslash-separated paths, which '/' would leave unflattened.
        return os.path.relpath(file, img_root).replace(os.sep, '_')
    return os.path.basename(file)


def _draw_detections(bgr, bboxes, scores, labels, class_names, colors):
    """Draw every box and its ``cls:<name>_score:<score>`` caption on ``bgr``.

    ``bgr`` is modified in place. Each row of ``bboxes`` is an integer
    ``x1, y1, x2, y2`` box; each label indexes ``colors`` and, when given,
    ``class_names``.
    """
    for bbox, score, label in zip(bboxes, scores, labels):
        x1, y1, x2, y2 = bbox.tolist()
        label = int(label)
        color = colors[label]
        label_name = label if class_names is None else class_names[label]
        name = f'cls:{label_name}_score:{float(score):0.4f}'

        cv2.rectangle(bgr, (x1, y1), (x2, y2), color, 2)
        # NOTE(review): 225 may be a typo for 255 (white) — confirm.
        cv2.putText(
            bgr,
            name, (x1, y1 - 2),
            cv2.FONT_HERSHEY_SIMPLEX,
            2.0, [225, 255, 255],
            thickness=3)


def main():
    """Run an exported detector over images and visualize the detections.

    ``.onnx`` checkpoints are loaded with ``ORTWrapper``; ``.engine`` and
    ``.plan`` checkpoints are loaded with ``TRTWrapper``. Each input image
    goes through the config's test pipeline. Boxes are mapped back to
    original-image coordinates, drawn, and then either shown (``--show``)
    or written to ``--out-dir``.

    Raises:
        NotImplementedError: If the checkpoint extension is not supported.
    """
    args = parse_args()

    # register all modules in mmdet into the registries
    register_all_modules()

    # One random color per class id (up to 1000 classes).
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(1000)]

    # build the model from the exported checkpoint file
    if args.checkpoint.endswith('.onnx'):
        model = ORTWrapper(args.checkpoint, args.device)
    elif args.checkpoint.endswith(('.engine', '.plan')):
        model = TRTWrapper(args.checkpoint, args.device)
    else:
        raise NotImplementedError(
            f'Unsupported checkpoint {args.checkpoint!r}: expected a '
            '.onnx, .engine or .plan file')

    model.to(args.device)

    cfg = Config.fromfile(args.config)
    class_names = cfg.get('class_name')

    test_pipeline = get_test_pipeline_cfg(cfg)
    # Images are decoded below with mmcv, so the first (loading) step is
    # replaced with one that accepts an in-memory ndarray.
    test_pipeline[0] = ConfigDict({'type': 'mmdet.LoadImageFromNDArray'})
    test_pipeline = Compose(test_pipeline)

    pre_pipeline = preprocess(cfg)

    if not args.show:
        path.mkdir_or_exist(args.out_dir)

    # get file list
    files, source_type = get_file_list(args.img)

    # start detector inference
    progress_bar = ProgressBar(len(files))
    for i, file in enumerate(files):
        bgr = mmcv.imread(file)
        rgb = mmcv.imconvert(bgr, 'bgr', 'rgb')
        # Look the outputs up by key instead of relying on dict order.
        packed = test_pipeline(dict(img=rgb, img_id=i))
        data, samples = packed['inputs'], packed['data_samples']

        pad_param = samples.get('pad_param',
                                np.array([0, 0, 0, 0], dtype=np.float32))
        h, w = samples.get('ori_shape', rgb.shape[:2])
        # Index 2 is used as the x offset and index 0 as the y offset;
        # presumably pad_param is (top, bottom, left, right) — confirm.
        pad_param = torch.asarray(
            [pad_param[2], pad_param[0], pad_param[2], pad_param[0]],
            device=args.device)
        # Expand (w_scale, h_scale) to match xyxy boxes. A tensor repeat is
        # used instead of ``seq * 2``, which concatenates a tuple but would
        # double the values of a numpy array. An already 4-element factor is
        # used as-is.
        scale_factor = torch.asarray(
            samples.get('scale_factor', [1., 1.]),
            dtype=torch.float32,
            device=args.device).flatten()
        if scale_factor.numel() == 2:
            scale_factor = scale_factor.repeat(2)
        data = pre_pipeline(data).to(args.device)

        result = model(data)
        out_file = None if args.show else os.path.join(
            args.out_dir,
            _output_filename(file, args.img, source_type['is_dir']))

        # Keep only the first num_dets candidates of the first batch item.
        num_dets, bboxes, scores, labels = result
        scores = scores[0, :num_dets]
        bboxes = bboxes[0, :num_dets]
        labels = labels[0, :num_dets]
        # Undo the padding, then the resize, to get original-image coords.
        bboxes -= pad_param
        bboxes /= scale_factor

        # Strided slices are views, so clamp_ edits bboxes in place.
        bboxes[:, 0::2].clamp_(0, w)
        bboxes[:, 1::2].clamp_(0, h)
        bboxes = bboxes.round().int()

        _draw_detections(bgr, bboxes, scores, labels, class_names, colors)

        if args.show:
            mmcv.imshow(bgr, 'result', 0)
        else:
            mmcv.imwrite(bgr, out_file)
        progress_bar.update()


# Script entry point: run the demo only when executed directly, not on import.
if __name__ == '__main__':
    main()