import math
import sys
from argparse import ArgumentParser
from pathlib import Path

import cv2
import onnxruntime
from config import (CLASS_COLORS, CLASS_NAMES, ModelType, YOLOv5_ANCHORS,
                    YOLOv7_ANCHORS)
from cv2_nms import non_max_suppression
from numpy_coder import Decoder
from preprocess import Preprocess
from tqdm import tqdm

# Add the directory containing this file to sys.path
sys.path.append(str(Path(__file__).resolve().parents[0]))

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
                  '.tiff', '.webp')


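def path_to_list(path: str):
    """Collect image paths from a single image file or a directory."""
    path = Path(path)
    # Compare suffixes case-insensitively so that e.g. '.JPG' is accepted.
    if path.is_file() and path.suffix.lower() in IMG_EXTENSIONS:
        res_list = [str(path.absolute())]
    elif path.is_dir():
        res_list = [
            str(p.absolute()) for p in path.iterdir()
            if p.suffix.lower() in IMG_EXTENSIONS
        ]
    else:
        raise RuntimeError(
            f'{path} is neither a supported image file nor a directory')
    return res_list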


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        'img', help='Path to an image file or a directory of images.')
    parser.add_argument('onnx', type=str, help='Onnx file')
    # Required: main() calls args.type.lower(), which fails on None.
    parser.add_argument(
        '--type', type=str, required=True, help='Model type')
    parser.add_argument(
        '--img-size',
        nargs='+',
        type=int,
        default=[640, 640],
        help='Image size of height and width')
    parser.add_argument(
        '--out-dir',
        default='./output',
        type=str,
        help='Directory to save output images')
    parser.add_argument(
        '--show', action='store_true', help='Show the detection results')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='Bbox score threshold')
    parser.add_argument(
        '--iou-thr', type=float, default=0.7, help='Bbox iou threshold')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    out_dir = Path(args.out_dir)
    model_type = ModelType(args.type.lower())

    if not args.show:
        out_dir.mkdir(parents=True, exist_ok=True)

    files = path_to_list(args.img)
    session = onnxruntime.InferenceSession(
        args.onnx, providers=['CPUExecutionProvider'])
    preprocessor = Preprocess(model_type)
    decoder = Decoder(model_type, model_only=True)
    if model_type == ModelType.YOLOV5:
        anchors = YOLOv5_ANCHORS
    elif model_type == ModelType.YOLOV7:
        anchors = YOLOv7_ANCHORS
    else:
        anchors = None

    for file in tqdm(files):
        image = cv2.imread(file)
        # cv2.imread returns None for unreadable files; skip them.
        if image is None:
            continue
        image_h, image_w = image.shape[:2]
        img, (ratio_w, ratio_h) = preprocessor(image, args.img_size)
        features = session.run(None, {'images': img})
        decoder_outputs = decoder(
            features,
            args.score_thr,
            num_labels=len(CLASS_NAMES),
            anchors=anchors)
        nmsd_boxes, nmsd_scores, nmsd_labels = non_max_suppression(
            *decoder_outputs, args.score_thr, args.iou_thr)
        for box, score, label in zip(nmsd_boxes, nmsd_scores, nmsd_labels):
            x0, y0, x1, y1 = box
            x0 = math.floor(min(max(x0 / ratio_w, 1), image_w - 1))
            y0 = math.floor(min(max(y0 / ratio_h, 1), image_h - 1))
            x1 = math.ceil(min(max(x1 / ratio_w, 1), image_w - 1))
            y1 = math.ceil(min(max(y1 / ratio_h, 1), image_h - 1))
            cv2.rectangle(image, (x0, y0), (x1, y1), CLASS_COLORS[label], 2)
            cv2.putText(image, f'{CLASS_NAMES[label]}: {score:.2f}',
                        (x0, y0 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                        (0, 255, 255), 2)
        if args.show:
            cv2.imshow('result', image)
            cv2.waitKey(0)
        else:
            cv2.imwrite(str(out_dir / Path(file).name), image)


if __name__ == '__main__':
    main()