zhengrongzhang
commited on
Commit
•
e6c79f4
1
Parent(s):
ae727cb
init model
Browse files- README.md +133 -0
- anchors.npy +3 -0
- coco.yaml +97 -0
- default.yaml +119 -0
- demo.jpg +0 -0
- general_json2yolo.py +183 -0
- onnx_eval.py +284 -0
- onnx_inference.py +148 -0
- requirements.txt +43 -0
- strides.npy +3 -0
- utils.py +2140 -0
- yolov8m_qat.onnx +3 -0
README.md
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
tags:
|
4 |
+
- RyzenAI
|
5 |
+
- object-detection
|
6 |
+
- vision
|
7 |
+
- YOLO
|
8 |
+
- Pytorch
|
9 |
+
datasets:
|
10 |
+
- COCO
|
11 |
+
metrics:
|
12 |
+
- mAP
|
13 |
+
---
|
14 |
+
# YOLOv8m model trained on COCO
|
15 |
+
|
16 |
+
YOLOv8m is the medium version of YOLOv8 model trained on COCO object detection (118k annotated images) at resolution 640x640. It was released in [https://github.com/ultralytics/ultralytics](https://github.com/ultralytics/ultralytics).
|
17 |
+
|
18 |
+
We develop a modified version that could be supported by [AMD Ryzen AI](https://onnxruntime.ai/docs/execution-providers/Vitis-AI-ExecutionProvider.html).
|
19 |
+
|
20 |
+
|
21 |
+
## Model description
|
22 |
+
|
23 |
+
Ultralytics YOLOv8 is a cutting-edge, state-of-the-art (SOTA) model that builds upon the success of previous YOLO versions and introduces new features and improvements to further boost performance and flexibility. YOLOv8 is designed to be fast, accurate, and easy to use, making it an excellent choice for a wide range of object detection and tracking, instance segmentation, image classification and pose estimation tasks.
|
24 |
+
|
25 |
+
|
26 |
+
## Intended uses & limitations
|
27 |
+
|
28 |
+
You can use the raw model for object detection. See the [model hub](https://huggingface.co/models?search=amd/yolov8) to look for all available YOLOv8 models.
|
29 |
+
|
30 |
+
|
31 |
+
## How to use
|
32 |
+
|
33 |
+
### Installation
|
34 |
+
|
35 |
+
Follow [Ryzen AI Installation](https://ryzenai.docs.amd.com/en/latest/inst.html) to prepare the environment for Ryzen AI.
|
36 |
+
Run the following script to install pre-requisites for this model.
|
37 |
+
```bash
|
38 |
+
pip install -r requirements.txt
|
39 |
+
```
|
40 |
+
|
41 |
+
|
42 |
+
### Data Preparation (optional: for accuracy evaluation)
|
43 |
+
|
44 |
+
The dataset MSCOCO2017 contains 118287 images for training and 5000 images for validation.
|
45 |
+
|
46 |
+
Download COCO dataset and create directories in your code like this:
|
47 |
+
```plain
|
48 |
+
└── datasets
|
49 |
+
└── coco
|
50 |
+
├── annotations
|
51 |
+
| ├── instances_val2017.json
|
52 |
+
| └── ...
|
53 |
+
├── labels
|
54 |
+
| ├── val2017
|
55 |
+
| | ├── 000000000139.txt
|
56 |
+
| ├── 000000000285.txt
|
57 |
+
| └── ...
|
58 |
+
├── images
|
59 |
+
| ├── val2017
|
60 |
+
| | ├── 000000000139.jpg
|
61 |
+
| ├── 000000000285.jpg
|
62 |
+
└── val2017.txt
|
63 |
+
```
|
64 |
+
1. put the val2017 image folder under images directory or use a softlink
|
65 |
+
2. the labels folder and val2017.txt above are generate by **general_json2yolo.py**
|
66 |
+
3. modify the coco.yaml like this:
|
67 |
+
```markdown
|
68 |
+
path: /path/to/your/datasets/coco # dataset root dir
|
69 |
+
train: train2017.txt # train images (relative to 'path') 118287 images
|
70 |
+
val: val2017.txt # val images (relative to 'path') 5000 images
|
71 |
+
```
|
72 |
+
|
73 |
+
|
74 |
+
### Test & Evaluation
|
75 |
+
|
76 |
+
- Code snippet from [`onnx_inference.py`](onnx_inference.py) on how to use
|
77 |
+
```python
|
78 |
+
args = make_parser().parse_args()
|
79 |
+
source = args.image_path
|
80 |
+
dataset = LoadImages(
|
81 |
+
source, imgsz=imgsz, stride=32, auto=False, transforms=None, vid_stride=1
|
82 |
+
)
|
83 |
+
onnx_weight = args.model
|
84 |
+
onnx_model = onnxruntime.InferenceSession(onnx_weight)
|
85 |
+
for batch in dataset:
|
86 |
+
path, im, im0s, vid_cap, s = batch
|
87 |
+
im = preprocess(im)
|
88 |
+
if len(im.shape) == 3:
|
89 |
+
im = im[None]
|
90 |
+
outputs = onnx_model.run(None, {onnx_model.get_inputs()[0].name: im.cpu().numpy()})
|
91 |
+
outputs = [torch.tensor(item) for item in outputs]
|
92 |
+
preds = post_process(outputs)
|
93 |
+
preds = non_max_suppression(
|
94 |
+
preds, 0.25, 0.7, agnostic=False, max_det=300, classes=None
|
95 |
+
)
|
96 |
+
plot_images(
|
97 |
+
im,
|
98 |
+
*output_to_target(preds, max_det=15),
|
99 |
+
source,
|
100 |
+
fname=args.output_path,
|
101 |
+
names=names,
|
102 |
+
)
|
103 |
+
|
104 |
+
```
|
105 |
+
|
106 |
+
- Run inference for a single image
|
107 |
+
```python
|
108 |
+
python onnx_inference.py -m ./yolov8m_qat.onnx -i /Path/To/Your/Image --ipu --provider_config /Path/To/Your/Provider_config
|
109 |
+
```
|
110 |
+
*Note: __vaip_config.json__ is located at the setup package of Ryzen AI (refer to [Installation](#installation))*
|
111 |
+
- Test accuracy of the quantized model
|
112 |
+
```python
|
113 |
+
python onnx_eval.py -m ./yolov8m_qat.onnx --ipu --provider_config /Path/To/Your/Provider_config
|
114 |
+
```
|
115 |
+
|
116 |
+
### Performance
|
117 |
+
|
118 |
+
|Metric |Accuracy on IPU|
|
119 |
+
| :----: | :----: |
|
120 |
+
|AP\@0.50:0.95|0.486|
|
121 |
+
|
122 |
+
|
123 |
+
```bibtex
|
124 |
+
@software{yolov8_ultralytics,
|
125 |
+
author = {Glenn Jocher and Ayush Chaurasia and Jing Qiu},
|
126 |
+
title = {Ultralytics YOLOv8},
|
127 |
+
version = {8.0.0},
|
128 |
+
year = {2023},
|
129 |
+
url = {https://github.com/ultralytics/ultralytics},
|
130 |
+
orcid = {0000-0001-5950-6979, 0000-0002-7603-6750, 0000-0003-3783-7069},
|
131 |
+
license = {AGPL-3.0}
|
132 |
+
}
|
133 |
+
```
|
anchors.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b677c029271bca12b3e4311e468d316962e31d1a2b7bae3eed90c9c65738fa31
|
3 |
+
size 67328
|
coco.yaml
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, GPL-3.0 license
|
2 |
+
# COCO 2017 dataset http://cocodataset.org by Microsoft
|
3 |
+
# Example usage: python train.py --data coco.yaml
|
4 |
+
# parent
|
5 |
+
# ├── yolov5
|
6 |
+
# └── datasets
|
7 |
+
# └── coco ← downloads here (20.1 GB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ./datasets/coco # dataset root dir
|
12 |
+
train: train2017.txt # train images (relative to 'path') 118287 images
|
13 |
+
val: val2017.txt # val images (relative to 'path') 5000 images
|
14 |
+
#test: val2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794
|
15 |
+
|
16 |
+
# Classes
|
17 |
+
names:
|
18 |
+
0: person
|
19 |
+
1: bicycle
|
20 |
+
2: car
|
21 |
+
3: motorcycle
|
22 |
+
4: airplane
|
23 |
+
5: bus
|
24 |
+
6: train
|
25 |
+
7: truck
|
26 |
+
8: boat
|
27 |
+
9: traffic light
|
28 |
+
10: fire hydrant
|
29 |
+
11: stop sign
|
30 |
+
12: parking meter
|
31 |
+
13: bench
|
32 |
+
14: bird
|
33 |
+
15: cat
|
34 |
+
16: dog
|
35 |
+
17: horse
|
36 |
+
18: sheep
|
37 |
+
19: cow
|
38 |
+
20: elephant
|
39 |
+
21: bear
|
40 |
+
22: zebra
|
41 |
+
23: giraffe
|
42 |
+
24: backpack
|
43 |
+
25: umbrella
|
44 |
+
26: handbag
|
45 |
+
27: tie
|
46 |
+
28: suitcase
|
47 |
+
29: frisbee
|
48 |
+
30: skis
|
49 |
+
31: snowboard
|
50 |
+
32: sports ball
|
51 |
+
33: kite
|
52 |
+
34: baseball bat
|
53 |
+
35: baseball glove
|
54 |
+
36: skateboard
|
55 |
+
37: surfboard
|
56 |
+
38: tennis racket
|
57 |
+
39: bottle
|
58 |
+
40: wine glass
|
59 |
+
41: cup
|
60 |
+
42: fork
|
61 |
+
43: knife
|
62 |
+
44: spoon
|
63 |
+
45: bowl
|
64 |
+
46: banana
|
65 |
+
47: apple
|
66 |
+
48: sandwich
|
67 |
+
49: orange
|
68 |
+
50: broccoli
|
69 |
+
51: carrot
|
70 |
+
52: hot dog
|
71 |
+
53: pizza
|
72 |
+
54: donut
|
73 |
+
55: cake
|
74 |
+
56: chair
|
75 |
+
57: couch
|
76 |
+
58: potted plant
|
77 |
+
59: bed
|
78 |
+
60: dining table
|
79 |
+
61: toilet
|
80 |
+
62: tv
|
81 |
+
63: laptop
|
82 |
+
64: mouse
|
83 |
+
65: remote
|
84 |
+
66: keyboard
|
85 |
+
67: cell phone
|
86 |
+
68: microwave
|
87 |
+
69: oven
|
88 |
+
70: toaster
|
89 |
+
71: sink
|
90 |
+
72: refrigerator
|
91 |
+
73: book
|
92 |
+
74: clock
|
93 |
+
75: vase
|
94 |
+
76: scissors
|
95 |
+
77: teddy bear
|
96 |
+
78: hair drier
|
97 |
+
79: toothbrush
|
default.yaml
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, GPL-3.0 license
|
2 |
+
# Default training settings and hyperparameters for medium-augmentation COCO training
|
3 |
+
|
4 |
+
task: detect # inference task, i.e. detect, segment, classify
|
5 |
+
mode: train # YOLO mode, i.e. train, val, predict, export
|
6 |
+
|
7 |
+
# Train settings -------------------------------------------------------------------------------------------------------
|
8 |
+
model: # path to model file, i.e. yolov8n.pt, yolov8n.yaml
|
9 |
+
data: "./coco.yaml" # path to data file, i.e. i.e. coco128.yaml
|
10 |
+
epochs: 100 # number of epochs to train for
|
11 |
+
patience: 50 # epochs to wait for no observable improvement for early stopping of training
|
12 |
+
batch: 1 # number of images per batch (-1 for AutoBatch)
|
13 |
+
imgsz: 640 # size of input images as integer or w,h
|
14 |
+
save: True # save train checkpoints and predict results
|
15 |
+
cache: False # True/ram, disk or False. Use cache for data loading
|
16 |
+
device: # device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu
|
17 |
+
workers: 8 # number of worker threads for data loading (per RANK if DDP)
|
18 |
+
project: # project name
|
19 |
+
name: # experiment name
|
20 |
+
exist_ok: False # whether to overwrite existing experiment
|
21 |
+
pretrained: False # whether to use a pretrained model
|
22 |
+
optimizer: SGD # optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
|
23 |
+
verbose: True # whether to print verbose output
|
24 |
+
seed: 0 # random seed for reproducibility
|
25 |
+
deterministic: True # whether to enable deterministic mode
|
26 |
+
single_cls: False # train multi-class data as single-class
|
27 |
+
image_weights: False # use weighted image selection for training
|
28 |
+
rect: False # support rectangular training if mode='train', support rectangular evaluation if mode='val'
|
29 |
+
cos_lr: False # use cosine learning rate scheduler
|
30 |
+
close_mosaic: 10 # disable mosaic augmentation for final 10 epochs
|
31 |
+
resume: False # resume training from last checkpoint
|
32 |
+
min_memory: False # minimize memory footprint loss function, choices=[False, True, <roll_out_thr>]
|
33 |
+
sync_bn: False # convert batchnorm to syncbatchnorm in model
|
34 |
+
nndct_quant: False # True for quant model
|
35 |
+
quant_mode: 'test' # calib or test
|
36 |
+
dump_xmodel: False # True for dump xmodel
|
37 |
+
dump_onnx: False # True for dump onnx
|
38 |
+
onnx_weight: "./yolov8m_qat.onnx"
|
39 |
+
onnx_runtime: False
|
40 |
+
ipu: False
|
41 |
+
provider_config: ''
|
42 |
+
# Segmentation
|
43 |
+
overlap_mask: True # masks should overlap during training (segment train only)
|
44 |
+
mask_ratio: 4 # mask downsample ratio (segment train only)
|
45 |
+
# Classification
|
46 |
+
dropout: 0.0 # use dropout regularization (classify train only)
|
47 |
+
|
48 |
+
# Val/Test settings ----------------------------------------------------------------------------------------------------
|
49 |
+
val: True # validate/test during training
|
50 |
+
save_json: False # save results to JSON file
|
51 |
+
save_hybrid: False # save hybrid version of labels (labels + additional predictions)
|
52 |
+
conf: # object confidence threshold for detection (default 0.25 predict, 0.001 val)
|
53 |
+
iou: 0.7 # intersection over union (IoU) threshold for NMS
|
54 |
+
max_det: 300 # maximum number of detections per image
|
55 |
+
half: False # use half precision (FP16)
|
56 |
+
dnn: False # use OpenCV DNN for ONNX inference
|
57 |
+
plots: True # save plots during train/val
|
58 |
+
|
59 |
+
# Prediction settings --------------------------------------------------------------------------------------------------
|
60 |
+
source: # source directory for images or videos
|
61 |
+
show: True # show results if possible
|
62 |
+
save_txt: True # save results as .txt file
|
63 |
+
save_conf: False # save results with confidence scores
|
64 |
+
save_crop: False # save cropped images with results
|
65 |
+
hide_labels: False # hide labels
|
66 |
+
hide_conf: False # hide confidence scores
|
67 |
+
vid_stride: 1 # video frame-rate stride
|
68 |
+
line_thickness: 3 # bounding box thickness (pixels)
|
69 |
+
visualize: False # visualize model features
|
70 |
+
augment: False # apply image augmentation to prediction sources
|
71 |
+
agnostic_nms: False # class-agnostic NMS
|
72 |
+
classes: # filter results by class, i.e. class=0, or class=[0,2,3]
|
73 |
+
retina_masks: False # use high-resolution segmentation masks
|
74 |
+
boxes: True # Show boxes in segmentation predictions
|
75 |
+
|
76 |
+
# Export settings ------------------------------------------------------------------------------------------------------
|
77 |
+
format: torchscript # format to export to
|
78 |
+
keras: False # use Keras
|
79 |
+
optimize: False # TorchScript: optimize for mobile
|
80 |
+
int8: False # CoreML/TF INT8 quantization
|
81 |
+
dynamic: False # ONNX/TF/TensorRT: dynamic axes
|
82 |
+
simplify: False # ONNX: simplify model
|
83 |
+
opset: # ONNX: opset version (optional)
|
84 |
+
workspace: 4 # TensorRT: workspace size (GB)
|
85 |
+
nms: False # CoreML: add NMS
|
86 |
+
|
87 |
+
# Hyperparameters ------------------------------------------------------------------------------------------------------
|
88 |
+
lr0: 0.01 # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
|
89 |
+
lrf: 0.01 # final learning rate (lr0 * lrf)
|
90 |
+
momentum: 0.937 # SGD momentum/Adam beta1
|
91 |
+
weight_decay: 0.0005 # optimizer weight decay 5e-4
|
92 |
+
warmup_epochs: 3.0 # warmup epochs (fractions ok)
|
93 |
+
warmup_momentum: 0.8 # warmup initial momentum
|
94 |
+
warmup_bias_lr: 0.1 # warmup initial bias lr
|
95 |
+
box: 7.5 # box loss gain
|
96 |
+
cls: 0.5 # cls loss gain (scale with pixels)
|
97 |
+
dfl: 1.5 # dfl loss gain
|
98 |
+
fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
|
99 |
+
label_smoothing: 0.0 # label smoothing (fraction)
|
100 |
+
nbs: 64 # nominal batch size
|
101 |
+
hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
|
102 |
+
hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
|
103 |
+
hsv_v: 0.4 # image HSV-Value augmentation (fraction)
|
104 |
+
degrees: 0.0 # image rotation (+/- deg)
|
105 |
+
translate: 0.1 # image translation (+/- fraction)
|
106 |
+
scale: 0.5 # image scale (+/- gain)
|
107 |
+
shear: 0.0 # image shear (+/- deg)
|
108 |
+
perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
|
109 |
+
flipud: 0.0 # image flip up-down (probability)
|
110 |
+
fliplr: 0.5 # image flip left-right (probability)
|
111 |
+
mosaic: 1.0 # image mosaic (probability)
|
112 |
+
mixup: 0.0 # image mixup (probability)
|
113 |
+
copy_paste: 0.0 # segment copy-paste (probability)
|
114 |
+
|
115 |
+
# Custom config.yaml ---------------------------------------------------------------------------------------------------
|
116 |
+
cfg: # for overriding defaults.yaml
|
117 |
+
|
118 |
+
# Debug, do not modify -------------------------------------------------------------------------------------------------
|
119 |
+
v5loader: False # use legacy YOLOv5 dataloader
|
demo.jpg
ADDED
general_json2yolo.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from collections import defaultdict
|
3 |
+
from pathlib import Path
|
4 |
+
from tqdm import tqdm
|
5 |
+
import numpy as np
|
6 |
+
import sys
|
7 |
+
import pathlib
|
8 |
+
|
9 |
+
CURRENT_DIR = pathlib.Path(__file__).parent
|
10 |
+
sys.path.append(str(CURRENT_DIR))
|
11 |
+
|
12 |
+
|
13 |
+
def make_dirs(dir="./datasets/coco"):
|
14 |
+
# Create folders
|
15 |
+
dir = Path(dir)
|
16 |
+
for p in [dir / "labels"]:
|
17 |
+
p.mkdir(parents=True, exist_ok=True) # make dir
|
18 |
+
return dir
|
19 |
+
|
20 |
+
|
21 |
+
def coco91_to_coco80_class(): # converts 80-index (val2014) to 91-index (paper)
|
22 |
+
# https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
|
23 |
+
x = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, None, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, None, 24, 25, None,
|
24 |
+
None, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, None, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
|
25 |
+
51, 52, 53, 54, 55, 56, 57, 58, 59, None, 60, None, None, 61, None, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72,
|
26 |
+
None, 73, 74, 75, 76, 77, 78, 79, None]
|
27 |
+
return x
|
28 |
+
|
29 |
+
|
30 |
+
def convert_coco_json(
|
31 |
+
json_dir="../coco/annotations/", use_segments=False, cls91to80=False
|
32 |
+
):
|
33 |
+
save_dir = make_dirs() # output directory
|
34 |
+
coco80 = coco91_to_coco80_class()
|
35 |
+
|
36 |
+
# Import json
|
37 |
+
for json_file in sorted(Path(json_dir).resolve().glob("*.json")):
|
38 |
+
if not str(json_file).endswith("instances_val2017.json"):
|
39 |
+
continue
|
40 |
+
fn = (
|
41 |
+
Path(save_dir) / "labels" / json_file.stem.replace("instances_", "")
|
42 |
+
) # folder name
|
43 |
+
fn.mkdir()
|
44 |
+
with open(json_file) as f:
|
45 |
+
data = json.load(f)
|
46 |
+
|
47 |
+
# Create image dict
|
48 |
+
images = {"%g" % x["id"]: x for x in data["images"]}
|
49 |
+
# Create image-annotations dict
|
50 |
+
imgToAnns = defaultdict(list)
|
51 |
+
for ann in data["annotations"]:
|
52 |
+
imgToAnns[ann["image_id"]].append(ann)
|
53 |
+
|
54 |
+
txt_file = open(Path(save_dir / "val2017").with_suffix(".txt"), "a")
|
55 |
+
# Write labels file
|
56 |
+
for img_id, anns in tqdm(imgToAnns.items(), desc=f"Annotations {json_file}"):
|
57 |
+
img = images["%g" % img_id]
|
58 |
+
h, w, f = img["height"], img["width"], img["file_name"]
|
59 |
+
bboxes = []
|
60 |
+
segments = []
|
61 |
+
|
62 |
+
txt_file.write(
|
63 |
+
"./images/" + "/".join(img["coco_url"].split("/")[-2:]) + "\n"
|
64 |
+
)
|
65 |
+
for ann in anns:
|
66 |
+
if ann["iscrowd"]:
|
67 |
+
continue
|
68 |
+
# The COCO box format is [top left x, top left y, width, height]
|
69 |
+
box = np.array(ann["bbox"], dtype=np.float64)
|
70 |
+
box[:2] += box[2:] / 2 # xy top-left corner to center
|
71 |
+
box[[0, 2]] /= w # normalize x
|
72 |
+
box[[1, 3]] /= h # normalize y
|
73 |
+
if box[2] <= 0 or box[3] <= 0: # if w <= 0 and h <= 0
|
74 |
+
continue
|
75 |
+
|
76 |
+
cls = (
|
77 |
+
coco80[ann["category_id"] - 1]
|
78 |
+
if cls91to80
|
79 |
+
else ann["category_id"] - 1
|
80 |
+
) # class
|
81 |
+
box = [cls] + box.tolist()
|
82 |
+
if box not in bboxes:
|
83 |
+
bboxes.append(box)
|
84 |
+
# Segments
|
85 |
+
if use_segments:
|
86 |
+
if len(ann["segmentation"]) > 1:
|
87 |
+
s = merge_multi_segment(ann["segmentation"])
|
88 |
+
s = (
|
89 |
+
(np.concatenate(s, axis=0) / np.array([w, h]))
|
90 |
+
.reshape(-1)
|
91 |
+
.tolist()
|
92 |
+
)
|
93 |
+
else:
|
94 |
+
s = [
|
95 |
+
j for i in ann["segmentation"] for j in i
|
96 |
+
] # all segments concatenated
|
97 |
+
s = (
|
98 |
+
(np.array(s).reshape(-1, 2) / np.array([w, h]))
|
99 |
+
.reshape(-1)
|
100 |
+
.tolist()
|
101 |
+
)
|
102 |
+
s = [cls] + s
|
103 |
+
if s not in segments:
|
104 |
+
segments.append(s)
|
105 |
+
|
106 |
+
# Write
|
107 |
+
with open((fn / f).with_suffix(".txt"), "a") as file:
|
108 |
+
for i in range(len(bboxes)):
|
109 |
+
line = (
|
110 |
+
*(segments[i] if use_segments else bboxes[i]),
|
111 |
+
) # cls, box or segments
|
112 |
+
file.write(("%g " * len(line)).rstrip() % line + "\n")
|
113 |
+
txt_file.close()
|
114 |
+
|
115 |
+
|
116 |
+
def min_index(arr1, arr2):
|
117 |
+
"""Find a pair of indexes with the shortest distance.
|
118 |
+
Args:
|
119 |
+
arr1: (N, 2).
|
120 |
+
arr2: (M, 2).
|
121 |
+
Return:
|
122 |
+
a pair of indexes(tuple).
|
123 |
+
"""
|
124 |
+
dis = ((arr1[:, None, :] - arr2[None, :, :]) ** 2).sum(-1)
|
125 |
+
return np.unravel_index(np.argmin(dis, axis=None), dis.shape)
|
126 |
+
|
127 |
+
|
128 |
+
def merge_multi_segment(segments):
|
129 |
+
"""Merge multi segments to one list.
|
130 |
+
Find the coordinates with min distance between each segment,
|
131 |
+
then connect these coordinates with one thin line to merge all
|
132 |
+
segments into one.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
segments(List(List)): original segmentations in coco's json file.
|
136 |
+
like [segmentation1, segmentation2,...],
|
137 |
+
each segmentation is a list of coordinates.
|
138 |
+
"""
|
139 |
+
s = []
|
140 |
+
segments = [np.array(i).reshape(-1, 2) for i in segments]
|
141 |
+
idx_list = [[] for _ in range(len(segments))]
|
142 |
+
|
143 |
+
# record the indexes with min distance between each segment
|
144 |
+
for i in range(1, len(segments)):
|
145 |
+
idx1, idx2 = min_index(segments[i - 1], segments[i])
|
146 |
+
idx_list[i - 1].append(idx1)
|
147 |
+
idx_list[i].append(idx2)
|
148 |
+
|
149 |
+
# use two round to connect all the segments
|
150 |
+
for k in range(2):
|
151 |
+
# forward connection
|
152 |
+
if k == 0:
|
153 |
+
for i, idx in enumerate(idx_list):
|
154 |
+
# middle segments have two indexes
|
155 |
+
# reverse the index of middle segments
|
156 |
+
if len(idx) == 2 and idx[0] > idx[1]:
|
157 |
+
idx = idx[::-1]
|
158 |
+
segments[i] = segments[i][::-1, :]
|
159 |
+
|
160 |
+
segments[i] = np.roll(segments[i], -idx[0], axis=0)
|
161 |
+
segments[i] = np.concatenate([segments[i], segments[i][:1]])
|
162 |
+
# deal with the first segment and the last one
|
163 |
+
if i in [0, len(idx_list) - 1]:
|
164 |
+
s.append(segments[i])
|
165 |
+
else:
|
166 |
+
idx = [0, idx[1] - idx[0]]
|
167 |
+
s.append(segments[i][idx[0] : idx[1] + 1])
|
168 |
+
|
169 |
+
else:
|
170 |
+
for i in range(len(idx_list) - 1, -1, -1):
|
171 |
+
if i not in [0, len(idx_list) - 1]:
|
172 |
+
idx = idx_list[i]
|
173 |
+
nidx = abs(idx[1] - idx[0])
|
174 |
+
s.append(segments[i][nidx:])
|
175 |
+
return s
|
176 |
+
|
177 |
+
|
178 |
+
if __name__ == "__main__":
|
179 |
+
convert_coco_json(
|
180 |
+
"./datasets/coco/annotations", # directory with *.json
|
181 |
+
use_segments=True,
|
182 |
+
cls91to80=True,
|
183 |
+
)
|
onnx_eval.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from pathlib import Path
|
3 |
+
import torch
|
4 |
+
import argparse
|
5 |
+
import numpy as np
|
6 |
+
from tqdm import tqdm
|
7 |
+
import onnxruntime
|
8 |
+
from utils import check_det_dataset, yaml_load, IterableSimpleNamespace, build_dataloader, post_process, xyxy2xywh, LOGGER, \
|
9 |
+
DetMetrics, increment_path, get_cfg, smart_inference_mode, box_iou, TQDM_BAR_FORMAT, scale_boxes, non_max_suppression, xywh2xyxy
|
10 |
+
|
11 |
+
# Default configuration
|
12 |
+
DEFAULT_CFG_DICT = yaml_load("./default.yaml")
|
13 |
+
for k, v in DEFAULT_CFG_DICT.items():
|
14 |
+
if isinstance(v, str) and v.lower() == 'none':
|
15 |
+
DEFAULT_CFG_DICT[k] = None
|
16 |
+
DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
|
17 |
+
DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
|
18 |
+
import sys
|
19 |
+
import pathlib
|
20 |
+
CURRENT_DIR = pathlib.Path(__file__).parent
|
21 |
+
sys.path.append(str(CURRENT_DIR))
|
22 |
+
|
23 |
+
|
24 |
+
class DetectionValidator:
|
25 |
+
|
26 |
+
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
|
27 |
+
self.dataloader = dataloader
|
28 |
+
self.pbar = pbar
|
29 |
+
self.logger = LOGGER
|
30 |
+
self.args = args
|
31 |
+
self.model = None
|
32 |
+
self.data = None
|
33 |
+
self.device = None
|
34 |
+
self.batch_i = None
|
35 |
+
self.speed = None
|
36 |
+
self.jdict = None
|
37 |
+
self.args.task = 'detect'
|
38 |
+
project = Path("./runs") / self.args.task
|
39 |
+
self.save_dir = save_dir or increment_path(Path(project),
|
40 |
+
exist_ok=True)
|
41 |
+
(self.save_dir / 'labels').mkdir(parents=True, exist_ok=True)
|
42 |
+
self.args.conf = 0.001 # default conf=0.001
|
43 |
+
self.is_coco = False
|
44 |
+
self.class_map = None
|
45 |
+
self.metrics = DetMetrics(save_dir=self.save_dir)
|
46 |
+
self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95
|
47 |
+
self.niou = self.iouv.numel()
|
48 |
+
|
49 |
+
@smart_inference_mode()
|
50 |
+
def __call__(self, trainer=None, model=None):
|
51 |
+
"""
|
52 |
+
Supports validation of a pre-trained model if passed or a model being trained
|
53 |
+
if trainer is passed (trainer gets priority).
|
54 |
+
"""
|
55 |
+
self.device = torch.device('cpu')
|
56 |
+
onnx_weight = self.args.onnx_weight
|
57 |
+
if isinstance(onnx_weight, list):
|
58 |
+
onnx_weight = onnx_weight[0]
|
59 |
+
if self.args.ipu:
|
60 |
+
providers = ["VitisAIExecutionProvider"]
|
61 |
+
provider_options = [{"config_file": self.args.provider_config}]
|
62 |
+
onnx_model = onnxruntime.InferenceSession(onnx_weight, providers=providers, provider_options=provider_options)
|
63 |
+
else:
|
64 |
+
onnx_model = onnxruntime.InferenceSession(onnx_weight)
|
65 |
+
self.data = check_det_dataset(self.args.data)
|
66 |
+
self.args.rect = False
|
67 |
+
self.dataloader = self.dataloader or self.get_dataloader(self.data.get("val") or self.data.get("test"), self.args.batch)
|
68 |
+
total = len(self.dataloader)
|
69 |
+
n_batches = len(self.dataloader)
|
70 |
+
desc = self.get_desc()
|
71 |
+
bar = tqdm(self.dataloader, desc, total, bar_format=TQDM_BAR_FORMAT)
|
72 |
+
self.init_metrics()
|
73 |
+
self.jdict = [] # empty before each val
|
74 |
+
|
75 |
+
for batch_i, batch in enumerate(bar):
|
76 |
+
self.batch_i = batch_i
|
77 |
+
# pre-process
|
78 |
+
batch = self.preprocess(batch)
|
79 |
+
|
80 |
+
# inference
|
81 |
+
outputs = onnx_model.run(None, {onnx_model.get_inputs()[0].name: batch["img"].cpu().numpy()})
|
82 |
+
outputs = [torch.tensor(item).to(self.device) for item in outputs]
|
83 |
+
preds = post_process(outputs)
|
84 |
+
|
85 |
+
# pre-process predictions
|
86 |
+
preds = self.postprocess(preds)
|
87 |
+
self.update_metrics(preds, batch)
|
88 |
+
stats = self.get_stats()
|
89 |
+
self.print_results()
|
90 |
+
if self.args.save_json and self.jdict:
|
91 |
+
with open(str(self.save_dir / "predictions.json"), 'w') as f:
|
92 |
+
self.logger.info(f"Saving {f.name}...")
|
93 |
+
json.dump(self.jdict, f) # flatten and save
|
94 |
+
stats = self.eval_json(stats) # update stats
|
95 |
+
return stats
|
96 |
+
|
97 |
+
def get_dataloader(self, dataset_path, batch_size):
|
98 |
+
# calculate stride - check if model is initialized
|
99 |
+
return build_dataloader(self.args, batch_size, img_path=dataset_path, stride=32, names=self.data['names'], mode="val")[0]
|
100 |
+
|
101 |
+
def get_desc(self):
|
102 |
+
return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)")
|
103 |
+
|
104 |
+
def init_metrics(self):
|
105 |
+
self.is_coco = True
|
106 |
+
self.class_map = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
|
107 |
+
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
|
108 |
+
64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
|
109 |
+
self.args.save_json = True
|
110 |
+
self.nc = 80
|
111 |
+
classnames = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
|
112 |
+
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
|
113 |
+
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
|
114 |
+
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
|
115 |
+
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
116 |
+
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
|
117 |
+
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
|
118 |
+
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
|
119 |
+
'hair drier', 'toothbrush']
|
120 |
+
self.names = {k: classnames[k] for k in range(80)}
|
121 |
+
self.metrics.names = self.names
|
122 |
+
self.metrics.plot = True
|
123 |
+
self.seen = 0
|
124 |
+
self.jdict = []
|
125 |
+
self.stats = []
|
126 |
+
|
127 |
+
def preprocess(self, batch):
|
128 |
+
batch["img"] = batch["img"].to(self.device, non_blocking=True)
|
129 |
+
batch["img"] = batch["img"].float() / 255
|
130 |
+
for k in ["batch_idx", "cls", "bboxes"]:
|
131 |
+
batch[k] = batch[k].to(self.device)
|
132 |
+
|
133 |
+
nb = len(batch["img"])
|
134 |
+
self.lb = [torch.cat([batch["cls"], batch["bboxes"]], dim=-1)[batch["batch_idx"] == i]
|
135 |
+
for i in range(nb)] if self.args.save_hybrid else [] # for autolabelling
|
136 |
+
|
137 |
+
return batch
|
138 |
+
|
139 |
+
def postprocess(self, preds):
|
140 |
+
preds = non_max_suppression(preds,
|
141 |
+
self.args.conf,
|
142 |
+
self.args.iou,
|
143 |
+
labels=self.lb,
|
144 |
+
multi_label=True,
|
145 |
+
agnostic=self.args.single_cls,
|
146 |
+
max_det=self.args.max_det)
|
147 |
+
return preds
|
148 |
+
|
149 |
+
def update_metrics(self, preds, batch):
|
150 |
+
# Metrics
|
151 |
+
for si, pred in enumerate(preds):
|
152 |
+
idx = batch["batch_idx"] == si
|
153 |
+
cls = batch["cls"][idx]
|
154 |
+
bbox = batch["bboxes"][idx]
|
155 |
+
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
|
156 |
+
shape = batch["ori_shape"][si]
|
157 |
+
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
158 |
+
self.seen += 1
|
159 |
+
|
160 |
+
if npr == 0:
|
161 |
+
if nl:
|
162 |
+
self.stats.append((correct_bboxes, *torch.zeros((2, 0), device=self.device), cls.squeeze(-1)))
|
163 |
+
continue
|
164 |
+
|
165 |
+
# Predictions
|
166 |
+
if self.args.single_cls:
|
167 |
+
pred[:, 5] = 0
|
168 |
+
predn = pred.clone()
|
169 |
+
scale_boxes(batch["img"][si].shape[1:], predn[:, :4], shape,
|
170 |
+
ratio_pad=batch["ratio_pad"][si]) # native-space pred
|
171 |
+
|
172 |
+
# Evaluate
|
173 |
+
if nl:
|
174 |
+
height, width = batch["img"].shape[2:]
|
175 |
+
tbox = xywh2xyxy(bbox) * torch.tensor(
|
176 |
+
(width, height, width, height), device=self.device) # target boxes
|
177 |
+
scale_boxes(batch["img"][si].shape[1:], tbox, shape,
|
178 |
+
ratio_pad=batch["ratio_pad"][si]) # native-space labels
|
179 |
+
labelsn = torch.cat((cls, tbox), 1) # native-space labels
|
180 |
+
correct_bboxes = self._process_batch(predn, labelsn)
|
181 |
+
self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], cls.squeeze(-1))) # (conf, pcls, tcls)
|
182 |
+
|
183 |
+
# Save
|
184 |
+
if self.args.save_json:
|
185 |
+
self.pred_to_json(predn, batch["im_file"][si])
|
186 |
+
|
187 |
+
def _process_batch(self, detections, labels):
|
188 |
+
"""
|
189 |
+
Return correct prediction matrix
|
190 |
+
Arguments:
|
191 |
+
detections (array[N, 6]), x1, y1, x2, y2, conf, class
|
192 |
+
labels (array[M, 5]), class, x1, y1, x2, y2
|
193 |
+
Returns:
|
194 |
+
correct (array[N, 10]), for 10 IoU levels
|
195 |
+
"""
|
196 |
+
iou = box_iou(labels[:, 1:], detections[:, :4])
|
197 |
+
correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool)
|
198 |
+
correct_class = labels[:, 0:1] == detections[:, 5]
|
199 |
+
for i in range(len(self.iouv)):
|
200 |
+
x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match
|
201 |
+
if x[0].shape[0]:
|
202 |
+
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
|
203 |
+
1).cpu().numpy() # [label, detect, iou]
|
204 |
+
if x[0].shape[0] > 1:
|
205 |
+
matches = matches[matches[:, 2].argsort()[::-1]]
|
206 |
+
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
|
207 |
+
# matches = matches[matches[:, 2].argsort()[::-1]]
|
208 |
+
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
|
209 |
+
correct[matches[:, 1].astype(int), i] = True
|
210 |
+
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
|
211 |
+
|
212 |
+
def pred_to_json(self, predn, filename):
|
213 |
+
stem = Path(filename).stem
|
214 |
+
image_id = int(stem) if stem.isnumeric() else stem
|
215 |
+
box = xyxy2xywh(predn[:, :4]) # xywh
|
216 |
+
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
|
217 |
+
for p, b in zip(predn.tolist(), box.tolist()):
|
218 |
+
self.jdict.append({
|
219 |
+
'image_id': image_id,
|
220 |
+
'category_id': self.class_map[int(p[5])],
|
221 |
+
'bbox': [round(x, 3) for x in b],
|
222 |
+
'score': round(p[4], 5)})
|
223 |
+
|
224 |
+
def get_stats(self):
|
225 |
+
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
|
226 |
+
if len(stats) and stats[0].any():
|
227 |
+
self.metrics.process(*stats)
|
228 |
+
self.nt_per_class = np.bincount(stats[-1].astype(int), minlength=self.nc) # number of targets per class
|
229 |
+
return self.metrics.results_dict
|
230 |
+
|
231 |
+
def print_results(self):
|
232 |
+
pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys) # print format
|
233 |
+
self.logger.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
|
234 |
+
if self.nt_per_class.sum() == 0:
|
235 |
+
self.logger.warning(
|
236 |
+
f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels')
|
237 |
+
|
238 |
+
# Print results per class
|
239 |
+
if self.args.verbose and self.nc > 1 and len(self.stats):
|
240 |
+
for i, c in enumerate(self.metrics.ap_class_index):
|
241 |
+
self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
|
242 |
+
|
243 |
+
|
244 |
+
def eval_json(self, stats):
|
245 |
+
if self.args.save_json and self.is_coco and len(self.jdict):
|
246 |
+
anno_json = Path(self.data['path']) / 'annotations/instances_val2017.json' # annotations
|
247 |
+
pred_json = self.save_dir / "predictions.json" # predictions
|
248 |
+
self.logger.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
|
249 |
+
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
|
250 |
+
from pycocotools.coco import COCO # noqa
|
251 |
+
from pycocotools.cocoeval import COCOeval # noqa
|
252 |
+
# for x in anno_json, pred_json:
|
253 |
+
# assert x.is_file(), f"{x} file not found"
|
254 |
+
anno = COCO(str(anno_json)) # init annotations api
|
255 |
+
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
|
256 |
+
eval = COCOeval(anno, pred, 'bbox')
|
257 |
+
if self.is_coco:
|
258 |
+
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
|
259 |
+
eval.evaluate()
|
260 |
+
eval.accumulate()
|
261 |
+
eval.summarize()
|
262 |
+
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50
|
263 |
+
except Exception as e:
|
264 |
+
self.logger.warning(f'pycocotools unable to run: {e}')
|
265 |
+
return stats
|
266 |
+
|
267 |
+
|
268 |
+
def parse_opt():
|
269 |
+
parser = argparse.ArgumentParser()
|
270 |
+
parser.add_argument('--ipu', action='store_true', help='flag for ryzen ai')
|
271 |
+
parser.add_argument('--provider_config', default='', type=str, help='provider config for ryzen ai')
|
272 |
+
parser.add_argument("-m", "--model", default="./yolov8m_qat.onnx", type=str, help='onnx_weight')
|
273 |
+
opt = parser.parse_args()
|
274 |
+
return opt
|
275 |
+
|
276 |
+
|
277 |
+
if __name__ == "__main__":
|
278 |
+
opt = parse_opt()
|
279 |
+
args = get_cfg(DEFAULT_CFG)
|
280 |
+
args.ipu = opt.ipu
|
281 |
+
args.onnx_weight = opt.model
|
282 |
+
args.provider_config = opt.provider_config
|
283 |
+
validator = DetectionValidator(args=args)
|
284 |
+
validator()
|
onnx_inference.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import onnxruntime
|
4 |
+
import numpy as np
|
5 |
+
import argparse
|
6 |
+
from utils import (
|
7 |
+
LoadImages,
|
8 |
+
non_max_suppression,
|
9 |
+
plot_images,
|
10 |
+
output_to_target,
|
11 |
+
)
|
12 |
+
import sys
|
13 |
+
import pathlib
|
14 |
+
CURRENT_DIR = pathlib.Path(__file__).parent
|
15 |
+
sys.path.append(str(CURRENT_DIR))
|
16 |
+
|
17 |
+
def preprocess(img):
|
18 |
+
img = torch.from_numpy(img)
|
19 |
+
img = img.float() # uint8 to fp16/32
|
20 |
+
img /= 255 # 0 - 255 to 0.0 - 1.0
|
21 |
+
return img
|
22 |
+
|
23 |
+
|
24 |
+
class DFL(nn.Module):
|
25 |
+
# Integral module of Distribution Focal Loss (DFL) proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
|
26 |
+
def __init__(self, c1=16):
|
27 |
+
super().__init__()
|
28 |
+
self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
|
29 |
+
x = torch.arange(c1, dtype=torch.float)
|
30 |
+
self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
|
31 |
+
self.c1 = c1
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
b, c, a = x.shape # batch, channels, anchors
|
35 |
+
return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(
|
36 |
+
b, 4, a
|
37 |
+
)
|
38 |
+
|
39 |
+
|
40 |
+
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
|
41 |
+
"""Transform distance(ltrb) to box(xywh or xyxy)."""
|
42 |
+
lt, rb = torch.split(distance, 2, dim)
|
43 |
+
x1y1 = anchor_points - lt
|
44 |
+
x2y2 = anchor_points + rb
|
45 |
+
if xywh:
|
46 |
+
c_xy = (x1y1 + x2y2) / 2
|
47 |
+
wh = x2y2 - x1y1
|
48 |
+
return torch.cat((c_xy, wh), dim) # xywh bbox
|
49 |
+
return torch.cat((x1y1, x2y2), dim) # xyxy bbox
|
50 |
+
|
51 |
+
|
52 |
+
def post_process(x):
|
53 |
+
dfl = DFL(16)
|
54 |
+
anchors = torch.tensor(
|
55 |
+
np.load(
|
56 |
+
"./anchors.npy",
|
57 |
+
allow_pickle=True,
|
58 |
+
)
|
59 |
+
)
|
60 |
+
strides = torch.tensor(
|
61 |
+
np.load(
|
62 |
+
"./strides.npy",
|
63 |
+
allow_pickle=True,
|
64 |
+
)
|
65 |
+
)
|
66 |
+
box, cls = torch.cat([xi.view(x[0].shape[0], 144, -1) for xi in x], 2).split(
|
67 |
+
(16 * 4, 80), 1
|
68 |
+
)
|
69 |
+
dbox = dist2bbox(dfl(box), anchors.unsqueeze(0), xywh=True, dim=1) * strides
|
70 |
+
y = torch.cat((dbox, cls.sigmoid()), 1)
|
71 |
+
return y, x
|
72 |
+
|
73 |
+
|
74 |
+
def make_parser():
|
75 |
+
parser = argparse.ArgumentParser("onnxruntime inference sample")
|
76 |
+
parser.add_argument(
|
77 |
+
"-m",
|
78 |
+
"--model",
|
79 |
+
type=str,
|
80 |
+
default="./yolov8m_qat.onnx",
|
81 |
+
help="input your onnx model.",
|
82 |
+
)
|
83 |
+
parser.add_argument(
|
84 |
+
"-i",
|
85 |
+
"--image_path",
|
86 |
+
type=str,
|
87 |
+
default='./demo.jpg',
|
88 |
+
help="path to your input image.",
|
89 |
+
)
|
90 |
+
parser.add_argument(
|
91 |
+
"-o",
|
92 |
+
"--output_path",
|
93 |
+
type=str,
|
94 |
+
default='./demo_infer.jpg',
|
95 |
+
help="path to your output directory.",
|
96 |
+
)
|
97 |
+
parser.add_argument(
|
98 |
+
"--ipu", action='store_true', help='flag for ryzen ai'
|
99 |
+
)
|
100 |
+
parser.add_argument(
|
101 |
+
"--provider_config", default='', type=str, help='provider config for ryzen ai'
|
102 |
+
)
|
103 |
+
return parser
|
104 |
+
|
105 |
+
classnames = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
|
106 |
+
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
|
107 |
+
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
|
108 |
+
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
|
109 |
+
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
110 |
+
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
|
111 |
+
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
|
112 |
+
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
|
113 |
+
'hair drier', 'toothbrush']
|
114 |
+
names = {k: classnames[k] for k in range(80)}
|
115 |
+
imgsz = [640, 640]
|
116 |
+
|
117 |
+
|
118 |
+
if __name__ == '__main__':
|
119 |
+
args = make_parser().parse_args()
|
120 |
+
source = args.image_path
|
121 |
+
dataset = LoadImages(
|
122 |
+
source, imgsz=imgsz, stride=32, auto=False, transforms=None, vid_stride=1
|
123 |
+
)
|
124 |
+
onnx_weight = args.model
|
125 |
+
if args.ipu:
|
126 |
+
providers = ["VitisAIExecutionProvider"]
|
127 |
+
provider_options = [{"config_file": args.provider_config}]
|
128 |
+
onnx_model = onnxruntime.InferenceSession(onnx_weight, providers=providers, provider_options=provider_options)
|
129 |
+
else:
|
130 |
+
onnx_model = onnxruntime.InferenceSession(onnx_weight)
|
131 |
+
for batch in dataset:
|
132 |
+
path, im, im0s, vid_cap, s = batch
|
133 |
+
im = preprocess(im)
|
134 |
+
if len(im.shape) == 3:
|
135 |
+
im = im[None]
|
136 |
+
outputs = onnx_model.run(None, {onnx_model.get_inputs()[0].name: im.cpu().numpy()})
|
137 |
+
outputs = [torch.tensor(item) for item in outputs]
|
138 |
+
preds = post_process(outputs)
|
139 |
+
preds = non_max_suppression(
|
140 |
+
preds, 0.25, 0.7, agnostic=False, max_det=300, classes=None
|
141 |
+
)
|
142 |
+
plot_images(
|
143 |
+
im,
|
144 |
+
*output_to_target(preds, max_det=15),
|
145 |
+
source,
|
146 |
+
fname=args.output_path,
|
147 |
+
names=names,
|
148 |
+
)
|
requirements.txt
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics requirements
|
2 |
+
# Usage: pip install -r requirements.txt
|
3 |
+
|
4 |
+
# Base ----------------------------------------
|
5 |
+
matplotlib>=3.2.2
|
6 |
+
numpy>=1.18.5
|
7 |
+
opencv-python>=4.6.0
|
8 |
+
Pillow>=7.1.2
|
9 |
+
PyYAML>=5.3.1
|
10 |
+
requests>=2.23.0
|
11 |
+
scipy>=1.4.1
|
12 |
+
torch>=1.7.0
|
13 |
+
torchvision>=0.8.1
|
14 |
+
tqdm>=4.64.0
|
15 |
+
|
16 |
+
# Logging -------------------------------------
|
17 |
+
tensorboard>=2.4.1
|
18 |
+
# clearml
|
19 |
+
# comet
|
20 |
+
|
21 |
+
# Plotting ------------------------------------
|
22 |
+
pandas>=1.1.4
|
23 |
+
seaborn>=0.11.0
|
24 |
+
|
25 |
+
# Export --------------------------------------
|
26 |
+
# onnxruntime
|
27 |
+
# coremltools>=6.0 # CoreML export
|
28 |
+
onnx>=1.12.0 # ONNX export
|
29 |
+
# onnx-simplifier>=0.4.1 # ONNX simplifier
|
30 |
+
# nvidia-pyindex # TensorRT export
|
31 |
+
# nvidia-tensorrt # TensorRT export
|
32 |
+
# scikit-learn==0.19.2 # CoreML quantization
|
33 |
+
# tensorflow>=2.4.1 # TF exports (-cpu, -aarch64, -macos)
|
34 |
+
# tensorflowjs>=3.9.0 # TF.js export
|
35 |
+
# openvino-dev>=2022.3 # OpenVINO export
|
36 |
+
|
37 |
+
# Extras --------------------------------------
|
38 |
+
ipython # interactive notebook
|
39 |
+
psutil # system utilization
|
40 |
+
thop>=0.1.1 # FLOPs computation
|
41 |
+
# albumentations>=1.0.3
|
42 |
+
pycocotools>=2.0.6 # COCO mAP
|
43 |
+
# roboflow
|
strides.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:21fa445f4e0732b0213026bb865551d7986579ea19b21d81763c218be0368b9e
|
3 |
+
size 33728
|
utils.py
ADDED
@@ -0,0 +1,2140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import threading
|
2 |
+
import os
|
3 |
+
import contextlib
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from PIL import Image, ImageDraw, ImageFont, ExifTags
|
7 |
+
from PIL import __version__ as pil_version
|
8 |
+
from multiprocessing.pool import ThreadPool
|
9 |
+
import numpy as np
|
10 |
+
from itertools import repeat
|
11 |
+
import glob
|
12 |
+
import cv2
|
13 |
+
import tempfile
|
14 |
+
import hashlib
|
15 |
+
from pathlib import Path
|
16 |
+
import time
|
17 |
+
import torchvision
|
18 |
+
import math
|
19 |
+
import re
|
20 |
+
from typing import List, Union, Dict
|
21 |
+
import pkg_resources as pkg
|
22 |
+
from types import SimpleNamespace
|
23 |
+
from torch.utils.data import Dataset, DataLoader
|
24 |
+
from tqdm import tqdm
|
25 |
+
import random
|
26 |
+
import yaml
|
27 |
+
import logging.config
|
28 |
+
import sys
|
29 |
+
import pathlib
|
30 |
+
CURRENT_DIR = pathlib.Path(__file__).parent
|
31 |
+
sys.path.append(str(CURRENT_DIR))
|
32 |
+
|
33 |
+
LOGGING_NAME = 'ultralytics'
|
34 |
+
LOGGER = logging.getLogger(LOGGING_NAME)
|
35 |
+
for fn in LOGGER.info, LOGGER.warning:
|
36 |
+
setattr(LOGGER, fn.__name__, lambda x: fn(x))
|
37 |
+
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm" # include image suffixes
|
38 |
+
VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv" # include video suffixes
|
39 |
+
TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' # tqdm bar format
|
40 |
+
NUM_THREADS = min(8, os.cpu_count())
|
41 |
+
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
|
42 |
+
_formats = ["xyxy", "xywh", "ltwh"]
|
43 |
+
CFG_FLOAT_KEYS = {'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear'}
|
44 |
+
CFG_FRACTION_KEYS = {
|
45 |
+
'dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr', 'fl_gamma',
|
46 |
+
'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud', 'fliplr', 'mosaic',
|
47 |
+
'mixup', 'copy_paste', 'conf', 'iou'}
|
48 |
+
CFG_INT_KEYS = {
|
49 |
+
'epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
|
50 |
+
'line_thickness', 'workspace', 'nbs'}
|
51 |
+
CFG_BOOL_KEYS = {
|
52 |
+
'save', 'exist_ok', 'pretrained', 'verbose', 'deterministic', 'single_cls', 'image_weights', 'rect', 'cos_lr',
|
53 |
+
'overlap_mask', 'val', 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf',
|
54 |
+
'save_crop', 'hide_labels', 'hide_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks', 'boxes', 'keras',
|
55 |
+
'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'v5loader'}
|
56 |
+
# Get orientation exif tag
|
57 |
+
for orientation in ExifTags.TAGS.keys():
|
58 |
+
if ExifTags.TAGS[orientation] == 'Orientation':
|
59 |
+
break
|
60 |
+
|
61 |
+
def segments2boxes(segments):
|
62 |
+
"""
|
63 |
+
It converts segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
|
64 |
+
|
65 |
+
Args:
|
66 |
+
segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates
|
67 |
+
|
68 |
+
Returns:
|
69 |
+
(np.ndarray): the xywh coordinates of the bounding boxes.
|
70 |
+
"""
|
71 |
+
boxes = []
|
72 |
+
for s in segments:
|
73 |
+
x, y = s.T # segment xy
|
74 |
+
boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
|
75 |
+
return xyxy2xywh(np.array(boxes)) # cls, xywh
|
76 |
+
|
77 |
+
|
78 |
+
def check_version(
|
79 |
+
current: str = "0.0.0",
|
80 |
+
minimum: str = "0.0.0",
|
81 |
+
name: str = "version ",
|
82 |
+
pinned: bool = False,
|
83 |
+
hard: bool = False,
|
84 |
+
verbose: bool = False,
|
85 |
+
) -> bool:
|
86 |
+
"""
|
87 |
+
Check current version against the required minimum version.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
current (str): Current version.
|
91 |
+
minimum (str): Required minimum version.
|
92 |
+
name (str): Name to be used in warning message.
|
93 |
+
pinned (bool): If True, versions must match exactly. If False, minimum version must be satisfied.
|
94 |
+
hard (bool): If True, raise an AssertionError if the minimum version is not met.
|
95 |
+
verbose (bool): If True, print warning message if minimum version is not met.
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
bool: True if minimum version is met, False otherwise.
|
99 |
+
"""
|
100 |
+
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
|
101 |
+
result = (current == minimum) if pinned else (current >= minimum) # bool
|
102 |
+
warning_message = f"WARNING ⚠️ {name}{minimum} is required by YOLOv8, but {name}{current} is currently installed"
|
103 |
+
if verbose and not result:
|
104 |
+
LOGGER.warning(warning_message)
|
105 |
+
return result
|
106 |
+
|
107 |
+
|
108 |
+
TORCH_1_9 = check_version(torch.__version__, '1.9.0')
|
109 |
+
|
110 |
+
|
111 |
+
def smart_inference_mode():
|
112 |
+
# Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
|
113 |
+
def decorate(fn):
|
114 |
+
return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)
|
115 |
+
|
116 |
+
return decorate
|
117 |
+
|
118 |
+
|
119 |
+
def box_iou(box1, box2, eps=1e-7):
|
120 |
+
# https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
|
121 |
+
"""
|
122 |
+
Return intersection-over-union (Jaccard index) of boxes.
|
123 |
+
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
|
124 |
+
Arguments:
|
125 |
+
box1 (Tensor[N, 4])
|
126 |
+
box2 (Tensor[M, 4])
|
127 |
+
Returns:
|
128 |
+
iou (Tensor[N, M]): the NxM matrix containing the pairwise
|
129 |
+
IoU values for every element in boxes1 and boxes2
|
130 |
+
"""
|
131 |
+
|
132 |
+
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
|
133 |
+
(a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
|
134 |
+
inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
|
135 |
+
|
136 |
+
# IoU = inter / (area1 + area2 - inter)
|
137 |
+
return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
|
138 |
+
|
139 |
+
|
140 |
+
class LoadImages:
|
141 |
+
# YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`
|
142 |
+
def __init__(
|
143 |
+
self, path, imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1
|
144 |
+
):
|
145 |
+
# *.txt file with img/vid/dir on each line
|
146 |
+
if isinstance(path, str) and Path(path).suffix == ".txt":
|
147 |
+
path = Path(path).read_text().rsplit()
|
148 |
+
files = []
|
149 |
+
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
|
150 |
+
p = str(Path(p).resolve())
|
151 |
+
if "*" in p:
|
152 |
+
files.extend(sorted(glob.glob(p, recursive=True))) # glob
|
153 |
+
elif os.path.isdir(p):
|
154 |
+
files.extend(sorted(glob.glob(os.path.join(p, "*.*")))) # dir
|
155 |
+
elif os.path.isfile(p):
|
156 |
+
files.append(p) # files
|
157 |
+
else:
|
158 |
+
raise FileNotFoundError(f"{p} does not exist")
|
159 |
+
# include image suffixes
|
160 |
+
images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS]
|
161 |
+
videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS]
|
162 |
+
ni, nv = len(images), len(videos)
|
163 |
+
|
164 |
+
self.imgsz = imgsz
|
165 |
+
self.stride = stride
|
166 |
+
self.files = images + videos
|
167 |
+
self.nf = ni + nv # number of files
|
168 |
+
self.video_flag = [False] * ni + [True] * nv
|
169 |
+
self.mode = "image"
|
170 |
+
self.auto = auto
|
171 |
+
self.transforms = transforms # optional
|
172 |
+
self.vid_stride = vid_stride # video frame-rate stride
|
173 |
+
self.bs = 1
|
174 |
+
if any(videos):
|
175 |
+
self.orientation = None # rotation degrees
|
176 |
+
self._new_video(videos[0]) # new video
|
177 |
+
else:
|
178 |
+
self.cap = None
|
179 |
+
if self.nf == 0:
|
180 |
+
raise FileNotFoundError(
|
181 |
+
f"No images or videos found in {p}. "
|
182 |
+
f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
|
183 |
+
)
|
184 |
+
|
185 |
+
def __iter__(self):
|
186 |
+
self.count = 0
|
187 |
+
return self
|
188 |
+
|
189 |
+
def __next__(self):
|
190 |
+
if self.count == self.nf:
|
191 |
+
raise StopIteration
|
192 |
+
path = self.files[self.count]
|
193 |
+
|
194 |
+
if self.video_flag[self.count]:
|
195 |
+
# Read video
|
196 |
+
self.mode = "video"
|
197 |
+
for _ in range(self.vid_stride):
|
198 |
+
self.cap.grab()
|
199 |
+
success, im0 = self.cap.retrieve()
|
200 |
+
while not success:
|
201 |
+
self.count += 1
|
202 |
+
self.cap.release()
|
203 |
+
if self.count == self.nf: # last video
|
204 |
+
raise StopIteration
|
205 |
+
path = self.files[self.count]
|
206 |
+
self._new_video(path)
|
207 |
+
success, im0 = self.cap.read()
|
208 |
+
|
209 |
+
self.frame += 1
|
210 |
+
s = f"video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: "
|
211 |
+
|
212 |
+
else:
|
213 |
+
# Read image
|
214 |
+
self.count += 1
|
215 |
+
im0 = cv2.imread(path) # BGR
|
216 |
+
if im0 is None:
|
217 |
+
raise FileNotFoundError(f"Image Not Found {path}")
|
218 |
+
s = f"image {self.count}/{self.nf} {path}: "
|
219 |
+
|
220 |
+
if self.transforms:
|
221 |
+
im = self.transforms(im0) # transforms
|
222 |
+
else:
|
223 |
+
im = LetterBox(self.imgsz, self.auto, stride=self.stride)(image=im0)
|
224 |
+
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
|
225 |
+
im = np.ascontiguousarray(im) # contiguous
|
226 |
+
|
227 |
+
return path, im, im0, self.cap, s
|
228 |
+
|
229 |
+
def _new_video(self, path):
|
230 |
+
# Create a new video capture object
|
231 |
+
self.frame = 0
|
232 |
+
self.cap = cv2.VideoCapture(path)
|
233 |
+
self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
|
234 |
+
if hasattr(cv2, "CAP_PROP_ORIENTATION_META"): # cv2<4.6.0 compatibility
|
235 |
+
self.orientation = int(
|
236 |
+
self.cap.get(cv2.CAP_PROP_ORIENTATION_META)
|
237 |
+
) # rotation degrees
|
238 |
+
# Disable auto-orientation due to known issues in https://github.com/ultralytics/yolov5/issues/8493
|
239 |
+
# self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0)
|
240 |
+
|
241 |
+
def _cv2_rotate(self, im):
|
242 |
+
# Rotate a cv2 video manually
|
243 |
+
if self.orientation == 0:
|
244 |
+
return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
|
245 |
+
elif self.orientation == 180:
|
246 |
+
return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
|
247 |
+
elif self.orientation == 90:
|
248 |
+
return cv2.rotate(im, cv2.ROTATE_180)
|
249 |
+
return im
|
250 |
+
|
251 |
+
def __len__(self):
|
252 |
+
return self.nf # number of files
|
253 |
+
|
254 |
+
|
255 |
+
class LetterBox:
|
256 |
+
"""Resize image and padding for detection, instance segmentation, pose"""
|
257 |
+
|
258 |
+
def __init__(
|
259 |
+
self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32
|
260 |
+
):
|
261 |
+
self.new_shape = new_shape
|
262 |
+
self.auto = auto
|
263 |
+
self.scaleFill = scaleFill
|
264 |
+
self.scaleup = scaleup
|
265 |
+
self.stride = stride
|
266 |
+
|
267 |
+
def __call__(self, labels=None, image=None):
|
268 |
+
if labels is None:
|
269 |
+
labels = {}
|
270 |
+
img = labels.get("img") if image is None else image
|
271 |
+
shape = img.shape[:2] # current shape [height, width]
|
272 |
+
new_shape = labels.pop("rect_shape", self.new_shape)
|
273 |
+
if isinstance(new_shape, int):
|
274 |
+
new_shape = (new_shape, new_shape)
|
275 |
+
|
276 |
+
# Scale ratio (new / old)
|
277 |
+
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
278 |
+
# only scale down, do not scale up (for better val mAP)
|
279 |
+
if not self.scaleup:
|
280 |
+
r = min(r, 1.0)
|
281 |
+
|
282 |
+
# Compute padding
|
283 |
+
ratio = r, r # width, height ratios
|
284 |
+
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
|
285 |
+
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
|
286 |
+
if self.auto: # minimum rectangle
|
287 |
+
dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride) # wh padding
|
288 |
+
elif self.scaleFill: # stretch
|
289 |
+
dw, dh = 0.0, 0.0
|
290 |
+
new_unpad = (new_shape[1], new_shape[0])
|
291 |
+
ratio = (
|
292 |
+
new_shape[1] / shape[1],
|
293 |
+
new_shape[0] / shape[0],
|
294 |
+
) # width, height ratios
|
295 |
+
|
296 |
+
dw /= 2 # divide padding into 2 sides
|
297 |
+
dh /= 2
|
298 |
+
if labels.get("ratio_pad"):
|
299 |
+
labels["ratio_pad"] = (labels["ratio_pad"], (dw, dh)) # for evaluation
|
300 |
+
|
301 |
+
if shape[::-1] != new_unpad: # resize
|
302 |
+
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
|
303 |
+
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
|
304 |
+
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
|
305 |
+
img = cv2.copyMakeBorder(
|
306 |
+
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
|
307 |
+
) # add border
|
308 |
+
|
309 |
+
if len(labels):
|
310 |
+
labels = self._update_labels(labels, ratio, dw, dh)
|
311 |
+
labels["img"] = img
|
312 |
+
labels["resized_shape"] = new_shape
|
313 |
+
return labels
|
314 |
+
else:
|
315 |
+
return img
|
316 |
+
|
317 |
+
def _update_labels(self, labels, ratio, padw, padh):
|
318 |
+
"""Update labels"""
|
319 |
+
labels["instances"].convert_bbox(format="xyxy")
|
320 |
+
labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
|
321 |
+
labels["instances"].scale(*ratio)
|
322 |
+
labels["instances"].add_padding(padw, padh)
|
323 |
+
return labels
|
324 |
+
|
325 |
+
|
326 |
+
class Annotator:
|
327 |
+
# YOLOv8 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
|
328 |
+
def __init__(
|
329 |
+
self,
|
330 |
+
im,
|
331 |
+
line_width=None,
|
332 |
+
font_size=None,
|
333 |
+
font="Arial.ttf",
|
334 |
+
pil=False,
|
335 |
+
example="abc",
|
336 |
+
):
|
337 |
+
assert (
|
338 |
+
im.data.contiguous
|
339 |
+
), "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images."
|
340 |
+
# non-latin labels, i.e. asian, arabic, cyrillic
|
341 |
+
non_ascii = not is_ascii(example)
|
342 |
+
self.pil = pil or non_ascii
|
343 |
+
if self.pil: # use PIL
|
344 |
+
self.pil_9_2_0_check = check_version(
|
345 |
+
pil_version, "9.2.0"
|
346 |
+
) # deprecation check
|
347 |
+
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
|
348 |
+
self.draw = ImageDraw.Draw(self.im)
|
349 |
+
self.font = ImageFont.load_default()
|
350 |
+
else: # use cv2
|
351 |
+
self.im = im
|
352 |
+
self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width
|
353 |
+
|
354 |
+
def box_label(
|
355 |
+
self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255)
|
356 |
+
):
|
357 |
+
# Add one xyxy box to image with label
|
358 |
+
if isinstance(box, torch.Tensor):
|
359 |
+
box = box.tolist()
|
360 |
+
if self.pil or not is_ascii(label):
|
361 |
+
self.draw.rectangle(box, width=self.lw, outline=color) # box
|
362 |
+
if label:
|
363 |
+
if self.pil_9_2_0_check:
|
364 |
+
_, _, w, h = self.font.getbbox(label) # text width, height (New)
|
365 |
+
else:
|
366 |
+
w, h = self.font.getsize(
|
367 |
+
label
|
368 |
+
) # text width, height (Old, deprecated in 9.2.0)
|
369 |
+
outside = box[1] - h >= 0 # label fits outside box
|
370 |
+
self.draw.rectangle(
|
371 |
+
(
|
372 |
+
box[0],
|
373 |
+
box[1] - h if outside else box[1],
|
374 |
+
box[0] + w + 1,
|
375 |
+
box[1] + 1 if outside else box[1] + h + 1,
|
376 |
+
),
|
377 |
+
fill=color,
|
378 |
+
)
|
379 |
+
# self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0
|
380 |
+
self.draw.text(
|
381 |
+
(box[0], box[1] - h if outside else box[1]),
|
382 |
+
label,
|
383 |
+
fill=txt_color,
|
384 |
+
font=self.font,
|
385 |
+
)
|
386 |
+
else: # cv2
|
387 |
+
p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
|
388 |
+
cv2.rectangle(
|
389 |
+
self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA
|
390 |
+
)
|
391 |
+
if label:
|
392 |
+
tf = max(self.lw - 1, 1) # font thickness
|
393 |
+
# text width, height
|
394 |
+
w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0]
|
395 |
+
outside = p1[1] - h >= 3
|
396 |
+
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
|
397 |
+
cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
|
398 |
+
cv2.putText(
|
399 |
+
self.im,
|
400 |
+
label,
|
401 |
+
(p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
|
402 |
+
0,
|
403 |
+
self.lw / 3,
|
404 |
+
txt_color,
|
405 |
+
thickness=tf,
|
406 |
+
lineType=cv2.LINE_AA,
|
407 |
+
)
|
408 |
+
|
409 |
+
def rectangle(self, xy, fill=None, outline=None, width=1):
|
410 |
+
# Add rectangle to image (PIL-only)
|
411 |
+
self.draw.rectangle(xy, fill, outline, width)
|
412 |
+
|
413 |
+
def text(self, xy, text, txt_color=(255, 255, 255), anchor="top"):
|
414 |
+
# Add text to image (PIL-only)
|
415 |
+
if anchor == "bottom": # start y from font bottom
|
416 |
+
w, h = self.font.getsize(text) # text width, height
|
417 |
+
xy[1] += 1 - h
|
418 |
+
self.draw.text(xy, text, fill=txt_color, font=self.font)
|
419 |
+
|
420 |
+
def fromarray(self, im):
|
421 |
+
# Update self.im from a numpy array
|
422 |
+
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
|
423 |
+
self.draw = ImageDraw.Draw(self.im)
|
424 |
+
|
425 |
+
def result(self):
|
426 |
+
# Return annotated image as array
|
427 |
+
return np.asarray(self.im)
|
428 |
+
|
429 |
+
|
430 |
+
def non_max_suppression(
|
431 |
+
prediction,
|
432 |
+
conf_thres=0.25,
|
433 |
+
iou_thres=0.45,
|
434 |
+
classes=None,
|
435 |
+
agnostic=False,
|
436 |
+
multi_label=False,
|
437 |
+
labels=(),
|
438 |
+
max_det=300,
|
439 |
+
nm=0, # number of masks
|
440 |
+
):
|
441 |
+
# Checks
|
442 |
+
assert (
|
443 |
+
0 <= conf_thres <= 1
|
444 |
+
), f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
|
445 |
+
assert (
|
446 |
+
0 <= iou_thres <= 1
|
447 |
+
), f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
|
448 |
+
# YOLOv8 model in validation model, output = (inference_out, loss_out)
|
449 |
+
if isinstance(prediction, (list, tuple)):
|
450 |
+
prediction = prediction[0] # select only inference output
|
451 |
+
device = prediction.device
|
452 |
+
mps = "mps" in device.type # Apple MPS
|
453 |
+
if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
|
454 |
+
prediction = prediction.cpu()
|
455 |
+
bs = prediction.shape[0] # batch size
|
456 |
+
nc = prediction.shape[1] - nm - 4 # number of classes
|
457 |
+
mi = 4 + nc # mask start index
|
458 |
+
xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates
|
459 |
+
|
460 |
+
# Settings
|
461 |
+
# min_wh = 2 # (pixels) minimum box width and height
|
462 |
+
max_wh = 7680 # (pixels) maximum box width and height
|
463 |
+
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
|
464 |
+
time_limit = 0.5 + 0.05 * bs # seconds to quit after
|
465 |
+
redundant = True # require redundant detections
|
466 |
+
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
|
467 |
+
merge = False # use merge-NMS
|
468 |
+
|
469 |
+
t = time.time()
|
470 |
+
output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
|
471 |
+
for xi, x in enumerate(prediction): # image index, image inference
|
472 |
+
# Apply constraints
|
473 |
+
# x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height
|
474 |
+
x = x.transpose(0, -1)[xc[xi]] # confidence
|
475 |
+
|
476 |
+
# Cat apriori labels if autolabelling
|
477 |
+
if labels and len(labels[xi]):
|
478 |
+
lb = labels[xi]
|
479 |
+
v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
|
480 |
+
v[:, :4] = lb[:, 1:5] # box
|
481 |
+
v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls
|
482 |
+
x = torch.cat((x, v), 0)
|
483 |
+
|
484 |
+
# If none remain process next image
|
485 |
+
if not x.shape[0]:
|
486 |
+
continue
|
487 |
+
|
488 |
+
# Detections matrix nx6 (xyxy, conf, cls)
|
489 |
+
box, cls, mask = x.split((4, nc, nm), 1)
|
490 |
+
# center_x, center_y, width, height) to (x1, y1, x2, y2)
|
491 |
+
box = xywh2xyxy(box)
|
492 |
+
if multi_label:
|
493 |
+
i, j = (cls > conf_thres).nonzero(as_tuple=False).T
|
494 |
+
x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
|
495 |
+
else: # best class only
|
496 |
+
conf, j = cls.max(1, keepdim=True)
|
497 |
+
x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
|
498 |
+
|
499 |
+
# Filter by class
|
500 |
+
if classes is not None:
|
501 |
+
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
|
502 |
+
|
503 |
+
# Check shape
|
504 |
+
n = x.shape[0] # number of boxes
|
505 |
+
if not n: # no boxes
|
506 |
+
continue
|
507 |
+
# sort by confidence and remove excess boxes
|
508 |
+
x = x[x[:, 4].argsort(descending=True)[:max_nms]]
|
509 |
+
|
510 |
+
# Batched NMS
|
511 |
+
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
|
512 |
+
# boxes (offset by class), scores
|
513 |
+
boxes, scores = x[:, :4] + c, x[:, 4]
|
514 |
+
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
|
515 |
+
i = i[:max_det] # limit detections
|
516 |
+
if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean)
|
517 |
+
# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
|
518 |
+
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
|
519 |
+
weights = iou * scores[None] # box weights
|
520 |
+
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(
|
521 |
+
1, keepdim=True
|
522 |
+
) # merged boxes
|
523 |
+
if redundant:
|
524 |
+
i = i[iou.sum(1) > 1] # require redundancy
|
525 |
+
|
526 |
+
output[xi] = x[i]
|
527 |
+
if mps:
|
528 |
+
output[xi] = output[xi].to(device)
|
529 |
+
if (time.time() - t) > time_limit:
|
530 |
+
LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded")
|
531 |
+
break # time limit exceeded
|
532 |
+
|
533 |
+
return output
|
534 |
+
|
535 |
+
|
536 |
+
class Colors:
|
537 |
+
# Ultralytics color palette https://ultralytics.com/
|
538 |
+
def __init__(self):
|
539 |
+
# hex = matplotlib.colors.TABLEAU_COLORS.values()
|
540 |
+
hexs = (
|
541 |
+
"FF3838",
|
542 |
+
"FF9D97",
|
543 |
+
"FF701F",
|
544 |
+
"FFB21D",
|
545 |
+
"CFD231",
|
546 |
+
"48F90A",
|
547 |
+
"92CC17",
|
548 |
+
"3DDB86",
|
549 |
+
"1A9334",
|
550 |
+
"00D4BB",
|
551 |
+
"2C99A8",
|
552 |
+
"00C2FF",
|
553 |
+
"344593",
|
554 |
+
"6473FF",
|
555 |
+
"0018EC",
|
556 |
+
"8438FF",
|
557 |
+
"520085",
|
558 |
+
"CB38FF",
|
559 |
+
"FF95C8",
|
560 |
+
"FF37C7",
|
561 |
+
)
|
562 |
+
self.palette = [self.hex2rgb(f"#{c}") for c in hexs]
|
563 |
+
self.n = len(self.palette)
|
564 |
+
|
565 |
+
def __call__(self, i, bgr=False):
|
566 |
+
c = self.palette[int(i) % self.n]
|
567 |
+
return (c[2], c[1], c[0]) if bgr else c
|
568 |
+
|
569 |
+
@staticmethod
|
570 |
+
def hex2rgb(h): # rgb order (PIL)
|
571 |
+
return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
|
572 |
+
|
573 |
+
|
574 |
+
colors = Colors() # create instance for 'from utils.plots import colors'
|
575 |
+
|
576 |
+
|
577 |
+
def threaded(func):
|
578 |
+
# Multi-threads a target function and returns thread. Usage: @threaded decorator
|
579 |
+
def wrapper(*args, **kwargs):
|
580 |
+
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
|
581 |
+
thread.start()
|
582 |
+
return thread
|
583 |
+
|
584 |
+
return wrapper
|
585 |
+
|
586 |
+
|
587 |
+
def plot_images(
|
588 |
+
images,
|
589 |
+
batch_idx,
|
590 |
+
cls,
|
591 |
+
bboxes,
|
592 |
+
masks=np.zeros(0, dtype=np.uint8),
|
593 |
+
paths=None,
|
594 |
+
fname="images.jpg",
|
595 |
+
names=None,
|
596 |
+
):
|
597 |
+
# Plot image grid with labels
|
598 |
+
if isinstance(images, torch.Tensor):
|
599 |
+
images = images.cpu().float().numpy()
|
600 |
+
if isinstance(cls, torch.Tensor):
|
601 |
+
cls = cls.cpu().numpy()
|
602 |
+
if isinstance(bboxes, torch.Tensor):
|
603 |
+
bboxes = bboxes.cpu().numpy()
|
604 |
+
if isinstance(masks, torch.Tensor):
|
605 |
+
masks = masks.cpu().numpy().astype(int)
|
606 |
+
if isinstance(batch_idx, torch.Tensor):
|
607 |
+
batch_idx = batch_idx.cpu().numpy()
|
608 |
+
|
609 |
+
max_size = 1920 # max image size
|
610 |
+
max_subplots = 16 # max image subplots, i.e. 4x4
|
611 |
+
bs, _, h, w = images.shape # batch size, _, height, width
|
612 |
+
bs = min(bs, max_subplots) # limit plot images
|
613 |
+
ns = np.ceil(bs**0.5) # number of subplots (square)
|
614 |
+
if np.max(images[0]) <= 1:
|
615 |
+
images *= 255 # de-normalise (optional)
|
616 |
+
|
617 |
+
# Build Image
|
618 |
+
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
|
619 |
+
for i, im in enumerate(images):
|
620 |
+
if i == max_subplots: # if last batch has fewer images than we expect
|
621 |
+
break
|
622 |
+
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
|
623 |
+
im = im.transpose(1, 2, 0)
|
624 |
+
mosaic[y : y + h, x : x + w, :] = im
|
625 |
+
|
626 |
+
# Resize (optional)
|
627 |
+
scale = max_size / ns / max(h, w)
|
628 |
+
if scale < 1:
|
629 |
+
h = math.ceil(scale * h)
|
630 |
+
w = math.ceil(scale * w)
|
631 |
+
mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
|
632 |
+
|
633 |
+
# Annotate
|
634 |
+
fs = int((h + w) * ns * 0.01) # font size
|
635 |
+
annotator = Annotator(
|
636 |
+
mosaic, line_width=2, font_size=fs, pil=True, example=names
|
637 |
+
)
|
638 |
+
for i in range(i + 1):
|
639 |
+
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
|
640 |
+
annotator.rectangle(
|
641 |
+
[x, y, x + w, y + h], None, (255, 255, 255), width=2
|
642 |
+
) # borders
|
643 |
+
if paths:
|
644 |
+
annotator.text(
|
645 |
+
# filenames
|
646 |
+
(x + 5, y + 5 + h),
|
647 |
+
text=Path(paths[i]).name[:40],
|
648 |
+
txt_color=(220, 220, 220),
|
649 |
+
)
|
650 |
+
if len(cls) > 0:
|
651 |
+
idx = batch_idx == i
|
652 |
+
|
653 |
+
boxes = xywh2xyxy(bboxes[idx, :4]).T
|
654 |
+
classes = cls[idx].astype("int")
|
655 |
+
labels = bboxes.shape[1] == 4 # labels if no conf column
|
656 |
+
# check for confidence presence (label vs pred)
|
657 |
+
conf = None if labels else bboxes[idx, 4]
|
658 |
+
|
659 |
+
if boxes.shape[1]:
|
660 |
+
if boxes.max() <= 1.01: # if normalized with tolerance 0.01
|
661 |
+
boxes[[0, 2]] *= w # scale to pixels
|
662 |
+
boxes[[1, 3]] *= h
|
663 |
+
elif scale < 1: # absolute coords need scale if image scales
|
664 |
+
boxes *= scale
|
665 |
+
boxes[[0, 2]] += x
|
666 |
+
boxes[[1, 3]] += y
|
667 |
+
for j, box in enumerate(boxes.T.tolist()):
|
668 |
+
c = classes[j]
|
669 |
+
color = colors(c)
|
670 |
+
c = names[c] if names else c
|
671 |
+
if labels or conf[j] > 0.25: # 0.25 conf thresh
|
672 |
+
label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
|
673 |
+
annotator.box_label(box, label, color=color)
|
674 |
+
annotator.im.save(fname) # save
|
675 |
+
|
676 |
+
|
677 |
+
def output_to_target(output, max_det=300):
|
678 |
+
# Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting
|
679 |
+
targets = []
|
680 |
+
for i, o in enumerate(output):
|
681 |
+
box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
|
682 |
+
j = torch.full((conf.shape[0], 1), i)
|
683 |
+
targets.append(torch.cat((j, cls, xyxy2xywh(box), conf), 1))
|
684 |
+
targets = torch.cat(targets, 0).numpy()
|
685 |
+
return targets[:, 0], targets[:, 1], targets[:, 2:]
|
686 |
+
|
687 |
+
|
688 |
+
def is_ascii(s=""):
|
689 |
+
# Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
|
690 |
+
s = str(s) # convert list, tuple, None, etc. to str
|
691 |
+
return len(s.encode().decode("ascii", "ignore")) == len(s)
|
692 |
+
|
693 |
+
|
694 |
+
def xyxy2xywh(x):
|
695 |
+
"""
|
696 |
+
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format.
|
697 |
+
|
698 |
+
Args:
|
699 |
+
x (np.ndarray) or (torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
|
700 |
+
Returns:
|
701 |
+
y (np.ndarray) or (torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
|
702 |
+
"""
|
703 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
704 |
+
y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
|
705 |
+
y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
|
706 |
+
y[..., 2] = x[..., 2] - x[..., 0] # width
|
707 |
+
y[..., 3] = x[..., 3] - x[..., 1] # height
|
708 |
+
return y
|
709 |
+
|
710 |
+
|
711 |
+
def xywh2xyxy(x):
|
712 |
+
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
|
713 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
714 |
+
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
|
715 |
+
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
|
716 |
+
y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
|
717 |
+
y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
|
718 |
+
return y
|
719 |
+
|
720 |
+
|
721 |
+
def check_det_dataset(dataset, autodownload=True):
|
722 |
+
# Download, check and/or unzip dataset if not found locally
|
723 |
+
data = dataset
|
724 |
+
# Download (optional)
|
725 |
+
extract_dir = ''
|
726 |
+
|
727 |
+
# Read yaml (optional)
|
728 |
+
if isinstance(data, (str, Path)):
|
729 |
+
data = yaml_load(data, append_filename=True) # dictionary
|
730 |
+
|
731 |
+
# Checks
|
732 |
+
if isinstance(data['names'], (list, tuple)): # old array format
|
733 |
+
data['names'] = dict(enumerate(data['names'])) # convert to dict
|
734 |
+
data['nc'] = len(data['names'])
|
735 |
+
|
736 |
+
# Resolve paths
|
737 |
+
path = Path(extract_dir or data.get('path') or Path(data.get('yaml_file', '')).parent) # dataset root
|
738 |
+
|
739 |
+
DATASETS_DIR = os.path.abspath('.')
|
740 |
+
if not path.is_absolute():
|
741 |
+
path = (DATASETS_DIR / path).resolve()
|
742 |
+
data['path'] = path # download scripts
|
743 |
+
for k in 'train', 'val', 'test':
|
744 |
+
if data.get(k): # prepend path
|
745 |
+
if isinstance(data[k], str):
|
746 |
+
x = (path / data[k]).resolve()
|
747 |
+
if not x.exists() and data[k].startswith('../'):
|
748 |
+
x = (path / data[k][3:]).resolve()
|
749 |
+
data[k] = str(x)
|
750 |
+
else:
|
751 |
+
data[k] = [str((path / x).resolve()) for x in data[k]]
|
752 |
+
|
753 |
+
# Parse yaml
|
754 |
+
train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
|
755 |
+
if val:
|
756 |
+
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
|
757 |
+
if not all(x.exists() for x in val):
|
758 |
+
msg = f"\nDataset '{dataset}' not found ⚠️, missing paths %s" % [str(x) for x in val if not x.exists()]
|
759 |
+
if s and autodownload:
|
760 |
+
LOGGER.warning(msg)
|
761 |
+
else:
|
762 |
+
raise FileNotFoundError(msg)
|
763 |
+
t = time.time()
|
764 |
+
if s.startswith('bash '): # bash script
|
765 |
+
LOGGER.info(f'Running {s} ...')
|
766 |
+
r = os.system(s)
|
767 |
+
else: # python script
|
768 |
+
r = exec(s, {'yaml': data}) # return None
|
769 |
+
dt = f'({round(time.time() - t, 1)}s)'
|
770 |
+
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
|
771 |
+
LOGGER.info(f"Dataset download {s}\n")
|
772 |
+
|
773 |
+
return data # dictionary
|
774 |
+
|
775 |
+
|
776 |
+
def yaml_load(file='data.yaml', append_filename=False):
|
777 |
+
"""
|
778 |
+
Load YAML data from a file.
|
779 |
+
|
780 |
+
Args:
|
781 |
+
file (str, optional): File name. Default is 'data.yaml'.
|
782 |
+
append_filename (bool): Add the YAML filename to the YAML dictionary. Default is False.
|
783 |
+
|
784 |
+
Returns:
|
785 |
+
dict: YAML data and file name.
|
786 |
+
"""
|
787 |
+
with open(file, errors='ignore', encoding='utf-8') as f:
|
788 |
+
# Add YAML filename to dict and return
|
789 |
+
s = f.read() # string
|
790 |
+
if not s.isprintable(): # remove special characters
|
791 |
+
s = re.sub(r'[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+', '', s)
|
792 |
+
return {**yaml.safe_load(s), 'yaml_file': str(file)} if append_filename else yaml.safe_load(s)
|
793 |
+
|
794 |
+
|
795 |
+
class IterableSimpleNamespace(SimpleNamespace):
|
796 |
+
"""
|
797 |
+
Iterable SimpleNamespace class to allow SimpleNamespace to be used with dict() and in for loops
|
798 |
+
"""
|
799 |
+
|
800 |
+
def __iter__(self):
|
801 |
+
return iter(vars(self).items())
|
802 |
+
|
803 |
+
def __str__(self):
|
804 |
+
return '\n'.join(f"{k}={v}" for k, v in vars(self).items())
|
805 |
+
|
806 |
+
def get(self, key, default=None):
|
807 |
+
return getattr(self, key, default)
|
808 |
+
|
809 |
+
|
810 |
+
def colorstr(*input):
|
811 |
+
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
|
812 |
+
*args, string = input if len(input) > 1 else ("blue", "bold", input[0]) # color arguments, string
|
813 |
+
colors = {
|
814 |
+
"black": "\033[30m", # basic colors
|
815 |
+
"red": "\033[31m",
|
816 |
+
"green": "\033[32m",
|
817 |
+
"yellow": "\033[33m",
|
818 |
+
"blue": "\033[34m",
|
819 |
+
"magenta": "\033[35m",
|
820 |
+
"cyan": "\033[36m",
|
821 |
+
"white": "\033[37m",
|
822 |
+
"bright_black": "\033[90m", # bright colors
|
823 |
+
"bright_red": "\033[91m",
|
824 |
+
"bright_green": "\033[92m",
|
825 |
+
"bright_yellow": "\033[93m",
|
826 |
+
"bright_blue": "\033[94m",
|
827 |
+
"bright_magenta": "\033[95m",
|
828 |
+
"bright_cyan": "\033[96m",
|
829 |
+
"bright_white": "\033[97m",
|
830 |
+
"end": "\033[0m", # misc
|
831 |
+
"bold": "\033[1m",
|
832 |
+
"underline": "\033[4m"}
|
833 |
+
return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
|
834 |
+
|
835 |
+
|
836 |
+
def seed_worker(worker_id):
|
837 |
+
# Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader
|
838 |
+
worker_seed = torch.initial_seed() % 2 ** 32
|
839 |
+
np.random.seed(worker_seed)
|
840 |
+
random.seed(worker_seed)
|
841 |
+
|
842 |
+
|
843 |
+
def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, rank=-1, mode="train"):
|
844 |
+
assert mode in ["train", "val"]
|
845 |
+
shuffle = mode == "train"
|
846 |
+
if cfg.rect and shuffle:
|
847 |
+
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
|
848 |
+
shuffle = False
|
849 |
+
dataset = YOLODataset(
|
850 |
+
img_path=img_path,
|
851 |
+
imgsz=cfg.imgsz,
|
852 |
+
batch_size=batch,
|
853 |
+
augment=mode == "train", # augmentation
|
854 |
+
hyp=cfg,
|
855 |
+
rect=cfg.rect or rect, # rectangular batches
|
856 |
+
cache=cfg.cache or None,
|
857 |
+
single_cls=cfg.single_cls or False,
|
858 |
+
stride=int(stride),
|
859 |
+
pad=0.0 if mode == "train" else 0.5,
|
860 |
+
prefix=colorstr(f"{mode}: "),
|
861 |
+
use_segments=cfg.task == "segment",
|
862 |
+
use_keypoints=cfg.task == "keypoint",
|
863 |
+
names=names)
|
864 |
+
|
865 |
+
batch = min(batch, len(dataset))
|
866 |
+
nd = torch.cuda.device_count() # number of CUDA devices
|
867 |
+
workers = cfg.workers if mode == "train" else cfg.workers * 2
|
868 |
+
nw = min([os.cpu_count() // max(nd, 1), batch if batch > 1 else 0, workers]) # number of workers
|
869 |
+
|
870 |
+
if rank == -1:
|
871 |
+
sampler = None
|
872 |
+
if cfg.image_weights or cfg.close_mosaic:
|
873 |
+
loader = DataLoader
|
874 |
+
generator = torch.Generator()
|
875 |
+
generator.manual_seed(6148914691236517205)
|
876 |
+
return loader(dataset=dataset,
|
877 |
+
batch_size=batch,
|
878 |
+
shuffle=shuffle and sampler is None,
|
879 |
+
num_workers=nw,
|
880 |
+
sampler=sampler,
|
881 |
+
pin_memory=PIN_MEMORY,
|
882 |
+
collate_fn=getattr(dataset, "collate_fn", None),
|
883 |
+
worker_init_fn=seed_worker,
|
884 |
+
generator=generator), dataset
|
885 |
+
|
886 |
+
|
887 |
+
class BaseDataset(Dataset):
|
888 |
+
"""Base Dataset.
|
889 |
+
Args:
|
890 |
+
img_path (str): image path.
|
891 |
+
pipeline (dict): a dict of image transforms.
|
892 |
+
label_path (str): label path, this can also be an ann_file or other custom label path.
|
893 |
+
"""
|
894 |
+
|
895 |
+
def __init__(
|
896 |
+
self,
|
897 |
+
img_path,
|
898 |
+
imgsz=640,
|
899 |
+
cache=False,
|
900 |
+
augment=True,
|
901 |
+
hyp=None,
|
902 |
+
prefix="",
|
903 |
+
rect=False,
|
904 |
+
batch_size=None,
|
905 |
+
stride=32,
|
906 |
+
pad=0.5,
|
907 |
+
single_cls=False,
|
908 |
+
):
|
909 |
+
super().__init__()
|
910 |
+
self.img_path = img_path
|
911 |
+
self.imgsz = imgsz
|
912 |
+
self.augment = augment
|
913 |
+
self.single_cls = single_cls
|
914 |
+
self.prefix = prefix
|
915 |
+
self.im_files = self.get_img_files(self.img_path)
|
916 |
+
self.labels = self.get_labels()
|
917 |
+
self.ni = len(self.labels)
|
918 |
+
|
919 |
+
# rect stuff
|
920 |
+
self.rect = rect
|
921 |
+
self.batch_size = batch_size
|
922 |
+
self.stride = stride
|
923 |
+
self.pad = pad
|
924 |
+
if self.rect:
|
925 |
+
assert self.batch_size is not None
|
926 |
+
self.set_rectangle()
|
927 |
+
|
928 |
+
# cache stuff
|
929 |
+
self.ims = [None] * self.ni
|
930 |
+
self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
|
931 |
+
if cache:
|
932 |
+
self.cache_images(cache)
|
933 |
+
|
934 |
+
# transforms
|
935 |
+
self.transforms = self.build_transforms(hyp=hyp)
|
936 |
+
|
937 |
+
def get_img_files(self, img_path):
|
938 |
+
"""Read image files."""
|
939 |
+
try:
|
940 |
+
f = [] # image files
|
941 |
+
for p in img_path if isinstance(img_path, list) else [img_path]:
|
942 |
+
p = Path(p) # os-agnostic
|
943 |
+
if p.is_dir(): # dir
|
944 |
+
f += glob.glob(str(p / "**" / "*.*"), recursive=True)
|
945 |
+
# f = list(p.rglob('*.*')) # pathlib
|
946 |
+
elif p.is_file(): # file
|
947 |
+
with open(p) as t:
|
948 |
+
t = t.read().strip().splitlines()
|
949 |
+
parent = str(p.parent) + os.sep
|
950 |
+
f += [x.replace("./", parent) if x.startswith("./") else x for x in t] # local to global path
|
951 |
+
# f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
|
952 |
+
else:
|
953 |
+
raise FileNotFoundError(f"{self.prefix}{p} does not exist")
|
954 |
+
im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
|
955 |
+
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
|
956 |
+
assert im_files, f"{self.prefix}No images found"
|
957 |
+
except Exception as e:
|
958 |
+
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n") from e
|
959 |
+
return im_files
|
960 |
+
|
961 |
+
def load_image(self, i):
|
962 |
+
# Loads 1 image from dataset index 'i', returns (im, resized hw)
|
963 |
+
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
|
964 |
+
if im is None: # not cached in RAM
|
965 |
+
if fn.exists(): # load npy
|
966 |
+
im = np.load(fn)
|
967 |
+
else: # read image
|
968 |
+
im = cv2.imread(f) # BGR
|
969 |
+
if im is None:
|
970 |
+
raise FileNotFoundError(f"Image Not Found {f}")
|
971 |
+
h0, w0 = im.shape[:2] # orig hw
|
972 |
+
r = self.imgsz / max(h0, w0) # ratio
|
973 |
+
if r != 1: # if sizes are not equal
|
974 |
+
interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
|
975 |
+
im = cv2.resize(im, (math.ceil(w0 * r), math.ceil(h0 * r)), interpolation=interp)
|
976 |
+
return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
|
977 |
+
return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
|
978 |
+
|
979 |
+
def cache_images(self, cache):
|
980 |
+
# cache images to memory or disk
|
981 |
+
gb = 0 # Gigabytes of cached images
|
982 |
+
self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
|
983 |
+
fcn = self.cache_images_to_disk if cache == "disk" else self.load_image
|
984 |
+
with ThreadPool(NUM_THREADS) as pool:
|
985 |
+
results = pool.imap(fcn, range(self.ni))
|
986 |
+
pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT)
|
987 |
+
for i, x in pbar:
|
988 |
+
if cache == "disk":
|
989 |
+
gb += self.npy_files[i].stat().st_size
|
990 |
+
else: # 'ram'
|
991 |
+
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
|
992 |
+
gb += self.ims[i].nbytes
|
993 |
+
pbar.desc = f"{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})"
|
994 |
+
pbar.close()
|
995 |
+
|
996 |
+
def cache_images_to_disk(self, i):
|
997 |
+
# Saves an image as an *.npy file for faster loading
|
998 |
+
f = self.npy_files[i]
|
999 |
+
if not f.exists():
|
1000 |
+
np.save(f.as_posix(), cv2.imread(self.im_files[i]))
|
1001 |
+
|
1002 |
+
def set_rectangle(self):
|
1003 |
+
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
|
1004 |
+
nb = bi[-1] + 1 # number of batches
|
1005 |
+
|
1006 |
+
s = np.array([x.pop("shape") for x in self.labels]) # hw
|
1007 |
+
ar = s[:, 0] / s[:, 1] # aspect ratio
|
1008 |
+
irect = ar.argsort()
|
1009 |
+
self.im_files = [self.im_files[i] for i in irect]
|
1010 |
+
self.labels = [self.labels[i] for i in irect]
|
1011 |
+
ar = ar[irect]
|
1012 |
+
|
1013 |
+
# Set training image shapes
|
1014 |
+
shapes = [[1, 1]] * nb
|
1015 |
+
for i in range(nb):
|
1016 |
+
ari = ar[bi == i]
|
1017 |
+
mini, maxi = ari.min(), ari.max()
|
1018 |
+
if maxi < 1:
|
1019 |
+
shapes[i] = [maxi, 1]
|
1020 |
+
elif mini > 1:
|
1021 |
+
shapes[i] = [1, 1 / mini]
|
1022 |
+
|
1023 |
+
self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
|
1024 |
+
self.batch = bi # batch index of image
|
1025 |
+
|
1026 |
+
def __getitem__(self, index):
|
1027 |
+
return self.transforms(self.get_label_info(index))
|
1028 |
+
|
1029 |
+
def get_label_info(self, index):
|
1030 |
+
label = self.labels[index].copy()
|
1031 |
+
label.pop("shape", None) # shape is for rect, remove it
|
1032 |
+
label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
|
1033 |
+
label["ratio_pad"] = (
|
1034 |
+
label["resized_shape"][0] / label["ori_shape"][0],
|
1035 |
+
label["resized_shape"][1] / label["ori_shape"][1],
|
1036 |
+
) # for evaluation
|
1037 |
+
if self.rect:
|
1038 |
+
label["rect_shape"] = self.batch_shapes[self.batch[index]]
|
1039 |
+
label = self.update_labels_info(label)
|
1040 |
+
return label
|
1041 |
+
|
1042 |
+
def __len__(self):
|
1043 |
+
return len(self.labels)
|
1044 |
+
|
1045 |
+
def update_labels_info(self, label):
|
1046 |
+
"""custom your label format here"""
|
1047 |
+
return label
|
1048 |
+
|
1049 |
+
def build_transforms(self, hyp=None):
|
1050 |
+
"""Users can custom augmentations here
|
1051 |
+
like:
|
1052 |
+
if self.augment:
|
1053 |
+
# training transforms
|
1054 |
+
return Compose([])
|
1055 |
+
else:
|
1056 |
+
# val transforms
|
1057 |
+
return Compose([])
|
1058 |
+
"""
|
1059 |
+
raise NotImplementedError
|
1060 |
+
|
1061 |
+
def get_labels(self):
|
1062 |
+
"""Users can custom their own format here.
|
1063 |
+
Make sure your output is a list with each element like below:
|
1064 |
+
dict(
|
1065 |
+
im_file=im_file,
|
1066 |
+
shape=shape, # format: (height, width)
|
1067 |
+
cls=cls,
|
1068 |
+
bboxes=bboxes, # xywh
|
1069 |
+
segments=segments, # xy
|
1070 |
+
keypoints=keypoints, # xy
|
1071 |
+
normalized=True, # or False
|
1072 |
+
bbox_format="xyxy", # or xywh, ltwh
|
1073 |
+
)
|
1074 |
+
"""
|
1075 |
+
raise NotImplementedError
|
1076 |
+
|
1077 |
+
|
1078 |
+
def img2label_paths(img_paths):
|
1079 |
+
# Define label paths as a function of image paths
|
1080 |
+
sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}" # /images/, /labels/ substrings
|
1081 |
+
return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
|
1082 |
+
|
1083 |
+
|
1084 |
+
def get_hash(paths):
|
1085 |
+
# Returns a single hash value of a list of paths (files or dirs)
|
1086 |
+
size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
|
1087 |
+
h = hashlib.md5(str(size).encode()) # hash sizes
|
1088 |
+
h.update("".join(paths).encode()) # hash paths
|
1089 |
+
return h.hexdigest() # return hash
|
1090 |
+
|
1091 |
+
|
1092 |
+
class Compose:
|
1093 |
+
|
1094 |
+
def __init__(self, transforms):
|
1095 |
+
self.transforms = transforms
|
1096 |
+
|
1097 |
+
def __call__(self, data):
|
1098 |
+
for t in self.transforms:
|
1099 |
+
data = t(data)
|
1100 |
+
return data
|
1101 |
+
|
1102 |
+
def append(self, transform):
|
1103 |
+
self.transforms.append(transform)
|
1104 |
+
|
1105 |
+
def tolist(self):
|
1106 |
+
return self.transforms
|
1107 |
+
|
1108 |
+
def __repr__(self):
|
1109 |
+
format_string = f"{self.__class__.__name__}("
|
1110 |
+
for t in self.transforms:
|
1111 |
+
format_string += "\n"
|
1112 |
+
format_string += f" {t}"
|
1113 |
+
format_string += "\n)"
|
1114 |
+
return format_string
|
1115 |
+
|
1116 |
+
|
1117 |
+
class Format:
|
1118 |
+
|
1119 |
+
def __init__(self,
|
1120 |
+
bbox_format="xywh",
|
1121 |
+
normalize=True,
|
1122 |
+
return_mask=False,
|
1123 |
+
return_keypoint=False,
|
1124 |
+
mask_ratio=4,
|
1125 |
+
mask_overlap=True,
|
1126 |
+
batch_idx=True):
|
1127 |
+
self.bbox_format = bbox_format
|
1128 |
+
self.normalize = normalize
|
1129 |
+
self.return_mask = return_mask # set False when training detection only
|
1130 |
+
self.return_keypoint = return_keypoint
|
1131 |
+
self.mask_ratio = mask_ratio
|
1132 |
+
self.mask_overlap = mask_overlap
|
1133 |
+
self.batch_idx = batch_idx # keep the batch indexes
|
1134 |
+
|
1135 |
+
def __call__(self, labels):
|
1136 |
+
img = labels.pop("img")
|
1137 |
+
h, w = img.shape[:2]
|
1138 |
+
cls = labels.pop("cls")
|
1139 |
+
instances = labels.pop("instances")
|
1140 |
+
instances.convert_bbox(format=self.bbox_format)
|
1141 |
+
instances.denormalize(w, h)
|
1142 |
+
nl = len(instances)
|
1143 |
+
|
1144 |
+
if self.normalize:
|
1145 |
+
instances.normalize(w, h)
|
1146 |
+
labels["img"] = self._format_img(img)
|
1147 |
+
labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
|
1148 |
+
labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
|
1149 |
+
if self.return_keypoint:
|
1150 |
+
labels["keypoints"] = torch.from_numpy(instances.keypoints) if nl else torch.zeros((nl, 17, 2))
|
1151 |
+
# then we can use collate_fn
|
1152 |
+
if self.batch_idx:
|
1153 |
+
labels["batch_idx"] = torch.zeros(nl)
|
1154 |
+
return labels
|
1155 |
+
|
1156 |
+
def _format_img(self, img):
|
1157 |
+
if len(img.shape) < 3:
|
1158 |
+
img = np.expand_dims(img, -1)
|
1159 |
+
img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1])
|
1160 |
+
img = torch.from_numpy(img)
|
1161 |
+
return img
|
1162 |
+
|
1163 |
+
class Bboxes:
|
1164 |
+
"""Now only numpy is supported"""
|
1165 |
+
|
1166 |
+
def __init__(self, bboxes, format="xyxy") -> None:
|
1167 |
+
assert format in _formats
|
1168 |
+
bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
|
1169 |
+
assert bboxes.ndim == 2
|
1170 |
+
assert bboxes.shape[1] == 4
|
1171 |
+
self.bboxes = bboxes
|
1172 |
+
self.format = format
|
1173 |
+
|
1174 |
+
def convert(self, format):
|
1175 |
+
assert format in _formats
|
1176 |
+
if self.format == format:
|
1177 |
+
return
|
1178 |
+
elif self.format == "xyxy":
|
1179 |
+
if format == "xywh":
|
1180 |
+
bboxes = xyxy2xywh(self.bboxes)
|
1181 |
+
elif self.format == "xywh":
|
1182 |
+
if format == "xyxy":
|
1183 |
+
bboxes = xywh2xyxy(self.bboxes)
|
1184 |
+
self.bboxes = bboxes
|
1185 |
+
self.format = format
|
1186 |
+
|
1187 |
+
def areas(self):
|
1188 |
+
self.convert("xyxy")
|
1189 |
+
return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])
|
1190 |
+
|
1191 |
+
def mul(self, scale):
|
1192 |
+
"""
|
1193 |
+
Args:
|
1194 |
+
scale (tuple | List | int): the scale for four coords.
|
1195 |
+
"""
|
1196 |
+
assert isinstance(scale, (tuple, list))
|
1197 |
+
assert len(scale) == 4
|
1198 |
+
self.bboxes[:, 0] *= scale[0]
|
1199 |
+
self.bboxes[:, 1] *= scale[1]
|
1200 |
+
self.bboxes[:, 2] *= scale[2]
|
1201 |
+
self.bboxes[:, 3] *= scale[3]
|
1202 |
+
|
1203 |
+
def add(self, offset):
|
1204 |
+
"""
|
1205 |
+
Args:
|
1206 |
+
offset (tuple | List | int): the offset for four coords.
|
1207 |
+
"""
|
1208 |
+
assert isinstance(offset, (tuple, list))
|
1209 |
+
assert len(offset) == 4
|
1210 |
+
self.bboxes[:, 0] += offset[0]
|
1211 |
+
self.bboxes[:, 1] += offset[1]
|
1212 |
+
self.bboxes[:, 2] += offset[2]
|
1213 |
+
self.bboxes[:, 3] += offset[3]
|
1214 |
+
|
1215 |
+
def __len__(self):
|
1216 |
+
return len(self.bboxes)
|
1217 |
+
|
1218 |
+
@classmethod
|
1219 |
+
def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes":
|
1220 |
+
"""
|
1221 |
+
Concatenates a list of Boxes into a single Bboxes
|
1222 |
+
|
1223 |
+
Arguments:
|
1224 |
+
boxes_list (list[Bboxes])
|
1225 |
+
|
1226 |
+
Returns:
|
1227 |
+
Bboxes: the concatenated Boxes
|
1228 |
+
"""
|
1229 |
+
assert isinstance(boxes_list, (list, tuple))
|
1230 |
+
if not boxes_list:
|
1231 |
+
return cls(np.empty(0))
|
1232 |
+
assert all(isinstance(box, Bboxes) for box in boxes_list)
|
1233 |
+
|
1234 |
+
if len(boxes_list) == 1:
|
1235 |
+
return boxes_list[0]
|
1236 |
+
return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))
|
1237 |
+
|
1238 |
+
def __getitem__(self, index) -> "Bboxes":
|
1239 |
+
"""
|
1240 |
+
Args:
|
1241 |
+
index: int, slice, or a BoolArray
|
1242 |
+
|
1243 |
+
Returns:
|
1244 |
+
Bboxes: Create a new :class:`Bboxes` by indexing.
|
1245 |
+
"""
|
1246 |
+
if isinstance(index, int):
|
1247 |
+
return Bboxes(self.bboxes[index].view(1, -1))
|
1248 |
+
b = self.bboxes[index]
|
1249 |
+
assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
|
1250 |
+
return Bboxes(b)
|
1251 |
+
|
1252 |
+
|
1253 |
+
def resample_segments(segments, n=1000):
|
1254 |
+
"""
|
1255 |
+
Inputs a list of segments (n,2) and returns a list of segments (n,2) up-sampled to n points each.
|
1256 |
+
|
1257 |
+
Args:
|
1258 |
+
segments (list): a list of (n,2) arrays, where n is the number of points in the segment.
|
1259 |
+
n (int): number of points to resample the segment to. Defaults to 1000
|
1260 |
+
|
1261 |
+
Returns:
|
1262 |
+
segments (list): the resampled segments.
|
1263 |
+
"""
|
1264 |
+
for i, s in enumerate(segments):
|
1265 |
+
s = np.concatenate((s, s[0:1, :]), axis=0)
|
1266 |
+
x = np.linspace(0, len(s) - 1, n)
|
1267 |
+
xp = np.arange(len(s))
|
1268 |
+
segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
|
1269 |
+
return segments
|
1270 |
+
|
1271 |
+
|
1272 |
+
class Instances:
|
1273 |
+
|
1274 |
+
def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
|
1275 |
+
"""
|
1276 |
+
Args:
|
1277 |
+
bboxes (ndarray): bboxes with shape [N, 4].
|
1278 |
+
segments (list | ndarray): segments.
|
1279 |
+
keypoints (ndarray): keypoints with shape [N, 17, 2].
|
1280 |
+
"""
|
1281 |
+
if segments is None:
|
1282 |
+
segments = []
|
1283 |
+
self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
|
1284 |
+
self.keypoints = keypoints
|
1285 |
+
self.normalized = normalized
|
1286 |
+
|
1287 |
+
if len(segments) > 0:
|
1288 |
+
# list[np.array(1000, 2)] * num_samples
|
1289 |
+
segments = resample_segments(segments)
|
1290 |
+
# (N, 1000, 2)
|
1291 |
+
segments = np.stack(segments, axis=0)
|
1292 |
+
else:
|
1293 |
+
segments = np.zeros((0, 1000, 2), dtype=np.float32)
|
1294 |
+
self.segments = segments
|
1295 |
+
|
1296 |
+
def convert_bbox(self, format):
|
1297 |
+
self._bboxes.convert(format=format)
|
1298 |
+
|
1299 |
+
def bbox_areas(self):
|
1300 |
+
self._bboxes.areas()
|
1301 |
+
|
1302 |
+
def scale(self, scale_w, scale_h, bbox_only=False):
|
1303 |
+
"""this might be similar with denormalize func but without normalized sign"""
|
1304 |
+
self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
|
1305 |
+
if bbox_only:
|
1306 |
+
return
|
1307 |
+
self.segments[..., 0] *= scale_w
|
1308 |
+
self.segments[..., 1] *= scale_h
|
1309 |
+
if self.keypoints is not None:
|
1310 |
+
self.keypoints[..., 0] *= scale_w
|
1311 |
+
self.keypoints[..., 1] *= scale_h
|
1312 |
+
|
1313 |
+
def denormalize(self, w, h):
|
1314 |
+
if not self.normalized:
|
1315 |
+
return
|
1316 |
+
self._bboxes.mul(scale=(w, h, w, h))
|
1317 |
+
self.segments[..., 0] *= w
|
1318 |
+
self.segments[..., 1] *= h
|
1319 |
+
if self.keypoints is not None:
|
1320 |
+
self.keypoints[..., 0] *= w
|
1321 |
+
self.keypoints[..., 1] *= h
|
1322 |
+
self.normalized = False
|
1323 |
+
|
1324 |
+
def normalize(self, w, h):
|
1325 |
+
if self.normalized:
|
1326 |
+
return
|
1327 |
+
self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h))
|
1328 |
+
self.segments[..., 0] /= w
|
1329 |
+
self.segments[..., 1] /= h
|
1330 |
+
if self.keypoints is not None:
|
1331 |
+
self.keypoints[..., 0] /= w
|
1332 |
+
self.keypoints[..., 1] /= h
|
1333 |
+
self.normalized = True
|
1334 |
+
|
1335 |
+
def add_padding(self, padw, padh):
|
1336 |
+
# handle rect and mosaic situation
|
1337 |
+
assert not self.normalized, "you should add padding with absolute coordinates."
|
1338 |
+
self._bboxes.add(offset=(padw, padh, padw, padh))
|
1339 |
+
self.segments[..., 0] += padw
|
1340 |
+
self.segments[..., 1] += padh
|
1341 |
+
if self.keypoints is not None:
|
1342 |
+
self.keypoints[..., 0] += padw
|
1343 |
+
self.keypoints[..., 1] += padh
|
1344 |
+
|
1345 |
+
def __getitem__(self, index) -> "Instances":
|
1346 |
+
"""
|
1347 |
+
Args:
|
1348 |
+
index: int, slice, or a BoolArray
|
1349 |
+
|
1350 |
+
Returns:
|
1351 |
+
Instances: Create a new :class:`Instances` by indexing.
|
1352 |
+
"""
|
1353 |
+
segments = self.segments[index] if len(self.segments) else self.segments
|
1354 |
+
keypoints = self.keypoints[index] if self.keypoints is not None else None
|
1355 |
+
bboxes = self.bboxes[index]
|
1356 |
+
bbox_format = self._bboxes.format
|
1357 |
+
return Instances(
|
1358 |
+
bboxes=bboxes,
|
1359 |
+
segments=segments,
|
1360 |
+
keypoints=keypoints,
|
1361 |
+
bbox_format=bbox_format,
|
1362 |
+
normalized=self.normalized,
|
1363 |
+
)
|
1364 |
+
|
1365 |
+
def flipud(self, h):
|
1366 |
+
if self._bboxes.format == "xyxy":
|
1367 |
+
y1 = self.bboxes[:, 1].copy()
|
1368 |
+
y2 = self.bboxes[:, 3].copy()
|
1369 |
+
self.bboxes[:, 1] = h - y2
|
1370 |
+
self.bboxes[:, 3] = h - y1
|
1371 |
+
else:
|
1372 |
+
self.bboxes[:, 1] = h - self.bboxes[:, 1]
|
1373 |
+
self.segments[..., 1] = h - self.segments[..., 1]
|
1374 |
+
if self.keypoints is not None:
|
1375 |
+
self.keypoints[..., 1] = h - self.keypoints[..., 1]
|
1376 |
+
|
1377 |
+
def fliplr(self, w):
|
1378 |
+
if self._bboxes.format == "xyxy":
|
1379 |
+
x1 = self.bboxes[:, 0].copy()
|
1380 |
+
x2 = self.bboxes[:, 2].copy()
|
1381 |
+
self.bboxes[:, 0] = w - x2
|
1382 |
+
self.bboxes[:, 2] = w - x1
|
1383 |
+
else:
|
1384 |
+
self.bboxes[:, 0] = w - self.bboxes[:, 0]
|
1385 |
+
self.segments[..., 0] = w - self.segments[..., 0]
|
1386 |
+
if self.keypoints is not None:
|
1387 |
+
self.keypoints[..., 0] = w - self.keypoints[..., 0]
|
1388 |
+
|
1389 |
+
def clip(self, w, h):
|
1390 |
+
ori_format = self._bboxes.format
|
1391 |
+
self.convert_bbox(format="xyxy")
|
1392 |
+
self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
|
1393 |
+
self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
|
1394 |
+
if ori_format != "xyxy":
|
1395 |
+
self.convert_bbox(format=ori_format)
|
1396 |
+
self.segments[..., 0] = self.segments[..., 0].clip(0, w)
|
1397 |
+
self.segments[..., 1] = self.segments[..., 1].clip(0, h)
|
1398 |
+
if self.keypoints is not None:
|
1399 |
+
self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)
|
1400 |
+
self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)
|
1401 |
+
|
1402 |
+
def update(self, bboxes, segments=None, keypoints=None):
|
1403 |
+
new_bboxes = Bboxes(bboxes, format=self._bboxes.format)
|
1404 |
+
self._bboxes = new_bboxes
|
1405 |
+
if segments is not None:
|
1406 |
+
self.segments = segments
|
1407 |
+
if keypoints is not None:
|
1408 |
+
self.keypoints = keypoints
|
1409 |
+
|
1410 |
+
def __len__(self):
|
1411 |
+
return len(self.bboxes)
|
1412 |
+
|
1413 |
+
@classmethod
|
1414 |
+
def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances":
|
1415 |
+
"""
|
1416 |
+
Concatenates a list of Boxes into a single Bboxes
|
1417 |
+
|
1418 |
+
Arguments:
|
1419 |
+
instances_list (list[Bboxes])
|
1420 |
+
axis
|
1421 |
+
|
1422 |
+
Returns:
|
1423 |
+
Boxes: the concatenated Boxes
|
1424 |
+
"""
|
1425 |
+
assert isinstance(instances_list, (list, tuple))
|
1426 |
+
if not instances_list:
|
1427 |
+
return cls(np.empty(0))
|
1428 |
+
assert all(isinstance(instance, Instances) for instance in instances_list)
|
1429 |
+
|
1430 |
+
if len(instances_list) == 1:
|
1431 |
+
return instances_list[0]
|
1432 |
+
|
1433 |
+
use_keypoint = instances_list[0].keypoints is not None
|
1434 |
+
bbox_format = instances_list[0]._bboxes.format
|
1435 |
+
normalized = instances_list[0].normalized
|
1436 |
+
|
1437 |
+
cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
|
1438 |
+
cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis)
|
1439 |
+
cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None
|
1440 |
+
return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized)
|
1441 |
+
|
1442 |
+
@property
|
1443 |
+
def bboxes(self):
|
1444 |
+
return self._bboxes.bboxes
|
1445 |
+
|
1446 |
+
|
1447 |
+
def is_dir_writeable(dir_path: Union[str, Path]) -> bool:
|
1448 |
+
"""
|
1449 |
+
Check if a directory is writeable.
|
1450 |
+
|
1451 |
+
Args:
|
1452 |
+
dir_path (str) or (Path): The path to the directory.
|
1453 |
+
|
1454 |
+
Returns:
|
1455 |
+
bool: True if the directory is writeable, False otherwise.
|
1456 |
+
"""
|
1457 |
+
try:
|
1458 |
+
with tempfile.TemporaryFile(dir=dir_path):
|
1459 |
+
pass
|
1460 |
+
return True
|
1461 |
+
except OSError:
|
1462 |
+
return False
|
1463 |
+
|
1464 |
+
|
1465 |
+
class YOLODataset(BaseDataset):
|
1466 |
+
cache_version = '1.0.1' # dataset labels *.cache version, >= 1.0.0 for YOLOv8
|
1467 |
+
rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
|
1468 |
+
"""YOLO Dataset.
|
1469 |
+
Args:
|
1470 |
+
img_path (str): image path.
|
1471 |
+
prefix (str): prefix.
|
1472 |
+
"""
|
1473 |
+
|
1474 |
+
def __init__(self,
|
1475 |
+
img_path,
|
1476 |
+
imgsz=640,
|
1477 |
+
cache=False,
|
1478 |
+
augment=True,
|
1479 |
+
hyp=None,
|
1480 |
+
prefix="",
|
1481 |
+
rect=False,
|
1482 |
+
batch_size=None,
|
1483 |
+
stride=32,
|
1484 |
+
pad=0.0,
|
1485 |
+
single_cls=False,
|
1486 |
+
use_segments=False,
|
1487 |
+
use_keypoints=False,
|
1488 |
+
names=None):
|
1489 |
+
self.use_segments = use_segments
|
1490 |
+
self.use_keypoints = use_keypoints
|
1491 |
+
self.names = names
|
1492 |
+
assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
|
1493 |
+
super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls)
|
1494 |
+
|
1495 |
+
def cache_labels(self, path=Path("./labels.cache")):
|
1496 |
+
# Cache dataset labels, check images and read shapes
|
1497 |
+
if path.exists():
|
1498 |
+
path.unlink() # remove *.cache file if exists
|
1499 |
+
x = {"labels": []}
|
1500 |
+
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
1501 |
+
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
|
1502 |
+
total = len(self.im_files)
|
1503 |
+
with ThreadPool(NUM_THREADS) as pool:
|
1504 |
+
results = pool.imap(func=verify_image_label,
|
1505 |
+
iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
|
1506 |
+
repeat(self.use_keypoints), repeat(len(self.names))))
|
1507 |
+
pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
|
1508 |
+
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
|
1509 |
+
nm += nm_f
|
1510 |
+
nf += nf_f
|
1511 |
+
ne += ne_f
|
1512 |
+
nc += nc_f
|
1513 |
+
if im_file:
|
1514 |
+
x["labels"].append(
|
1515 |
+
dict(
|
1516 |
+
im_file=im_file,
|
1517 |
+
shape=shape,
|
1518 |
+
cls=lb[:, 0:1], # n, 1
|
1519 |
+
bboxes=lb[:, 1:], # n, 4
|
1520 |
+
segments=segments,
|
1521 |
+
keypoints=keypoint,
|
1522 |
+
normalized=True,
|
1523 |
+
bbox_format="xywh"))
|
1524 |
+
if msg:
|
1525 |
+
msgs.append(msg)
|
1526 |
+
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
1527 |
+
pbar.close()
|
1528 |
+
|
1529 |
+
if msgs:
|
1530 |
+
LOGGER.info("\n".join(msgs))
|
1531 |
+
x["hash"] = get_hash(self.label_files + self.im_files)
|
1532 |
+
x["results"] = nf, nm, ne, nc, len(self.im_files)
|
1533 |
+
x["msgs"] = msgs # warnings
|
1534 |
+
x["version"] = self.cache_version # cache version
|
1535 |
+
self.im_files = [lb["im_file"] for lb in x["labels"]] # update im_files
|
1536 |
+
if is_dir_writeable(path.parent):
|
1537 |
+
np.save(str(path), x) # save cache for next time
|
1538 |
+
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
|
1539 |
+
LOGGER.info(f"{self.prefix}New cache created: {path}")
|
1540 |
+
else:
|
1541 |
+
LOGGER.warning(f"{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable") # not writeable
|
1542 |
+
return x
|
1543 |
+
|
1544 |
+
def get_labels(self):
|
1545 |
+
self.label_files = img2label_paths(self.im_files)
|
1546 |
+
cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
|
1547 |
+
try:
|
1548 |
+
cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True # load dict
|
1549 |
+
assert cache["version"] == self.cache_version # matches current version
|
1550 |
+
assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
|
1551 |
+
except (FileNotFoundError, AssertionError, AttributeError):
|
1552 |
+
cache, exists = self.cache_labels(cache_path), False # run cache ops
|
1553 |
+
|
1554 |
+
# Display cache
|
1555 |
+
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
|
1556 |
+
if exists:
|
1557 |
+
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
1558 |
+
tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results
|
1559 |
+
if cache["msgs"]:
|
1560 |
+
LOGGER.info("\n".join(cache["msgs"])) # display warnings
|
1561 |
+
|
1562 |
+
# Read cache
|
1563 |
+
[cache.pop(k) for k in ("hash", "version", "msgs")] # remove items
|
1564 |
+
labels = cache["labels"]
|
1565 |
+
|
1566 |
+
# Check if the dataset is all boxes or all segments
|
1567 |
+
len_cls = sum(len(lb["cls"]) for lb in labels)
|
1568 |
+
len_boxes = sum(len(lb["bboxes"]) for lb in labels)
|
1569 |
+
len_segments = sum(len(lb["segments"]) for lb in labels)
|
1570 |
+
if len_segments and len_boxes != len_segments:
|
1571 |
+
LOGGER.warning(
|
1572 |
+
f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
|
1573 |
+
f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
|
1574 |
+
"To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.")
|
1575 |
+
for lb in labels:
|
1576 |
+
lb["segments"] = []
|
1577 |
+
return labels
|
1578 |
+
|
1579 |
+
|
1580 |
+
def build_transforms(self, hyp=None):
|
1581 |
+
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
|
1582 |
+
transforms.append(
|
1583 |
+
Format(bbox_format="xywh",
|
1584 |
+
normalize=True,
|
1585 |
+
return_mask=self.use_segments,
|
1586 |
+
return_keypoint=self.use_keypoints,
|
1587 |
+
batch_idx=True,
|
1588 |
+
mask_ratio=hyp.mask_ratio,
|
1589 |
+
mask_overlap=hyp.overlap_mask))
|
1590 |
+
return transforms
|
1591 |
+
|
1592 |
+
def close_mosaic(self, hyp):
|
1593 |
+
hyp.mosaic = 0.0 # set mosaic ratio=0.0
|
1594 |
+
hyp.copy_paste = 0.0 # keep the same behavior as previous v8 close-mosaic
|
1595 |
+
hyp.mixup = 0.0 # keep the same behavior as previous v8 close-mosaic
|
1596 |
+
self.transforms = self.build_transforms(hyp)
|
1597 |
+
|
1598 |
+
def update_labels_info(self, label):
|
1599 |
+
"""custom your label format here"""
|
1600 |
+
# NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
|
1601 |
+
# we can make it also support classification and semantic segmentation by add or remove some dict keys there.
|
1602 |
+
bboxes = label.pop("bboxes")
|
1603 |
+
segments = label.pop("segments")
|
1604 |
+
keypoints = label.pop("keypoints", None)
|
1605 |
+
bbox_format = label.pop("bbox_format")
|
1606 |
+
normalized = label.pop("normalized")
|
1607 |
+
label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
|
1608 |
+
return label
|
1609 |
+
|
1610 |
+
@staticmethod
|
1611 |
+
def collate_fn(batch):
|
1612 |
+
new_batch = {}
|
1613 |
+
keys = batch[0].keys()
|
1614 |
+
values = list(zip(*[list(b.values()) for b in batch]))
|
1615 |
+
for i, k in enumerate(keys):
|
1616 |
+
value = values[i]
|
1617 |
+
if k == "img":
|
1618 |
+
value = torch.stack(value, 0)
|
1619 |
+
if k in ["masks", "keypoints", "bboxes", "cls"]:
|
1620 |
+
value = torch.cat(value, 0)
|
1621 |
+
new_batch[k] = value
|
1622 |
+
new_batch["batch_idx"] = list(new_batch["batch_idx"])
|
1623 |
+
for i in range(len(new_batch["batch_idx"])):
|
1624 |
+
new_batch["batch_idx"][i] += i # add target image index for build_targets()
|
1625 |
+
new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
|
1626 |
+
return new_batch
|
1627 |
+
|
1628 |
+
|
1629 |
+
class DFL(nn.Module):
|
1630 |
+
# Integral module of Distribution Focal Loss (DFL) proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
|
1631 |
+
def __init__(self, c1=16):
|
1632 |
+
super().__init__()
|
1633 |
+
self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
|
1634 |
+
x = torch.arange(c1, dtype=torch.float)
|
1635 |
+
self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
|
1636 |
+
self.c1 = c1
|
1637 |
+
|
1638 |
+
def forward(self, x):
|
1639 |
+
b, c, a = x.shape # batch, channels, anchors
|
1640 |
+
return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(
|
1641 |
+
b, 4, a
|
1642 |
+
)
|
1643 |
+
|
1644 |
+
|
1645 |
+
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
|
1646 |
+
"""Transform distance(ltrb) to box(xywh or xyxy)."""
|
1647 |
+
lt, rb = torch.split(distance, 2, dim)
|
1648 |
+
x1y1 = anchor_points - lt
|
1649 |
+
x2y2 = anchor_points + rb
|
1650 |
+
if xywh:
|
1651 |
+
c_xy = (x1y1 + x2y2) / 2
|
1652 |
+
wh = x2y2 - x1y1
|
1653 |
+
return torch.cat((c_xy, wh), dim) # xywh bbox
|
1654 |
+
return torch.cat((x1y1, x2y2), dim) # xyxy bbox
|
1655 |
+
|
1656 |
+
|
1657 |
+
def post_process(x):
|
1658 |
+
dfl = DFL(16)
|
1659 |
+
anchors = torch.tensor(
|
1660 |
+
np.load(
|
1661 |
+
"./anchors.npy",
|
1662 |
+
allow_pickle=True,
|
1663 |
+
)
|
1664 |
+
)
|
1665 |
+
strides = torch.tensor(
|
1666 |
+
np.load(
|
1667 |
+
"./strides.npy",
|
1668 |
+
allow_pickle=True,
|
1669 |
+
)
|
1670 |
+
)
|
1671 |
+
box, cls = torch.cat([xi.view(x[0].shape[0], 144, -1) for xi in x], 2).split(
|
1672 |
+
(16 * 4, 80), 1
|
1673 |
+
)
|
1674 |
+
dbox = dist2bbox(dfl(box), anchors.unsqueeze(0), xywh=True, dim=1) * strides
|
1675 |
+
y = torch.cat((dbox, cls.sigmoid()), 1)
|
1676 |
+
return y, x
|
1677 |
+
|
1678 |
+
|
1679 |
+
def smooth(y, f=0.05):
|
1680 |
+
# Box filter of fraction f
|
1681 |
+
nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd)
|
1682 |
+
p = np.ones(nf // 2) # ones padding
|
1683 |
+
yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded
|
1684 |
+
return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed
|
1685 |
+
|
1686 |
+
|
1687 |
+
def compute_ap(recall, precision):
|
1688 |
+
""" Compute the average precision, given the recall and precision curves
|
1689 |
+
# Arguments
|
1690 |
+
recall: The recall curve (list)
|
1691 |
+
precision: The precision curve (list)
|
1692 |
+
# Returns
|
1693 |
+
Average precision, precision curve, recall curve
|
1694 |
+
"""
|
1695 |
+
|
1696 |
+
# Append sentinel values to beginning and end
|
1697 |
+
mrec = np.concatenate(([0.0], recall, [1.0]))
|
1698 |
+
mpre = np.concatenate(([1.0], precision, [0.0]))
|
1699 |
+
|
1700 |
+
# Compute the precision envelope
|
1701 |
+
mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
|
1702 |
+
|
1703 |
+
# Integrate area under curve
|
1704 |
+
method = 'interp' # methods: 'continuous', 'interp'
|
1705 |
+
if method == 'interp':
|
1706 |
+
x = np.linspace(0, 1, 101) # 101-point interp (COCO)
|
1707 |
+
ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate
|
1708 |
+
else: # 'continuous'
|
1709 |
+
i = np.where(mrec[1:] != mrec[:-1])[0] # points where x-axis (recall) changes
|
1710 |
+
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve
|
1711 |
+
|
1712 |
+
return ap, mpre, mrec
|
1713 |
+
|
1714 |
+
|
1715 |
+
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), names=(), eps=1e-16, prefix=""):
|
1716 |
+
""" Compute the average precision, given the recall and precision curves.
|
1717 |
+
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
|
1718 |
+
# Arguments
|
1719 |
+
tp: True positives (nparray, nx1 or nx10).
|
1720 |
+
conf: Objectness value from 0-1 (nparray).
|
1721 |
+
pred_cls: Predicted object classes (nparray).
|
1722 |
+
target_cls: True object classes (nparray).
|
1723 |
+
plot: Plot precision-recall curve at mAP@0.5
|
1724 |
+
save_dir: Plot save directory
|
1725 |
+
# Returns
|
1726 |
+
The average precision as computed in py-faster-rcnn.
|
1727 |
+
"""
|
1728 |
+
|
1729 |
+
# Sort by objectness
|
1730 |
+
i = np.argsort(-conf)
|
1731 |
+
tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
|
1732 |
+
|
1733 |
+
# Find unique classes
|
1734 |
+
unique_classes, nt = np.unique(target_cls, return_counts=True)
|
1735 |
+
nc = unique_classes.shape[0] # number of classes, number of detections
|
1736 |
+
|
1737 |
+
# Create Precision-Recall curve and compute AP for each class
|
1738 |
+
px, py = np.linspace(0, 1, 1000), [] # for plotting
|
1739 |
+
ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
|
1740 |
+
for ci, c in enumerate(unique_classes):
|
1741 |
+
i = pred_cls == c
|
1742 |
+
n_l = nt[ci] # number of labels
|
1743 |
+
n_p = i.sum() # number of predictions
|
1744 |
+
if n_p == 0 or n_l == 0:
|
1745 |
+
continue
|
1746 |
+
|
1747 |
+
# Accumulate FPs and TPs
|
1748 |
+
fpc = (1 - tp[i]).cumsum(0)
|
1749 |
+
tpc = tp[i].cumsum(0)
|
1750 |
+
|
1751 |
+
# Recall
|
1752 |
+
recall = tpc / (n_l + eps) # recall curve
|
1753 |
+
r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases
|
1754 |
+
|
1755 |
+
# Precision
|
1756 |
+
precision = tpc / (tpc + fpc) # precision curve
|
1757 |
+
p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score
|
1758 |
+
|
1759 |
+
# AP from recall-precision curve
|
1760 |
+
for j in range(tp.shape[1]):
|
1761 |
+
ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
|
1762 |
+
if plot and j == 0:
|
1763 |
+
py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5
|
1764 |
+
|
1765 |
+
# Compute F1 (harmonic mean of precision and recall)
|
1766 |
+
f1 = 2 * p * r / (p + r + eps)
|
1767 |
+
names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
|
1768 |
+
names = dict(enumerate(names)) # to dict
|
1769 |
+
|
1770 |
+
i = smooth(f1.mean(0), 0.1).argmax() # max F1 index
|
1771 |
+
p, r, f1 = p[:, i], r[:, i], f1[:, i]
|
1772 |
+
tp = (r * nt).round() # true positives
|
1773 |
+
fp = (tp / (p + eps) - tp).round() # false positives
|
1774 |
+
return tp, fp, p, r, f1, ap, unique_classes.astype(int)
|
1775 |
+
|
1776 |
+
|
1777 |
+
class Metric:
|
1778 |
+
|
1779 |
+
def __init__(self) -> None:
|
1780 |
+
self.p = [] # (nc, )
|
1781 |
+
self.r = [] # (nc, )
|
1782 |
+
self.f1 = [] # (nc, )
|
1783 |
+
self.all_ap = [] # (nc, 10)
|
1784 |
+
self.ap_class_index = [] # (nc, )
|
1785 |
+
self.nc = 0
|
1786 |
+
|
1787 |
+
@property
|
1788 |
+
def ap50(self):
|
1789 |
+
"""AP@0.5 of all classes.
|
1790 |
+
Return:
|
1791 |
+
(nc, ) or [].
|
1792 |
+
"""
|
1793 |
+
return self.all_ap[:, 0] if len(self.all_ap) else []
|
1794 |
+
|
1795 |
+
@property
|
1796 |
+
def ap(self):
|
1797 |
+
"""AP@0.5:0.95
|
1798 |
+
Return:
|
1799 |
+
(nc, ) or [].
|
1800 |
+
"""
|
1801 |
+
return self.all_ap.mean(1) if len(self.all_ap) else []
|
1802 |
+
|
1803 |
+
@property
|
1804 |
+
def mp(self):
|
1805 |
+
"""mean precision of all classes.
|
1806 |
+
Return:
|
1807 |
+
float.
|
1808 |
+
"""
|
1809 |
+
return self.p.mean() if len(self.p) else 0.0
|
1810 |
+
|
1811 |
+
@property
|
1812 |
+
def mr(self):
|
1813 |
+
"""mean recall of all classes.
|
1814 |
+
Return:
|
1815 |
+
float.
|
1816 |
+
"""
|
1817 |
+
return self.r.mean() if len(self.r) else 0.0
|
1818 |
+
|
1819 |
+
@property
|
1820 |
+
def map50(self):
|
1821 |
+
"""Mean AP@0.5 of all classes.
|
1822 |
+
Return:
|
1823 |
+
float.
|
1824 |
+
"""
|
1825 |
+
return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0
|
1826 |
+
|
1827 |
+
@property
|
1828 |
+
def map75(self):
|
1829 |
+
"""Mean AP@0.75 of all classes.
|
1830 |
+
Return:
|
1831 |
+
float.
|
1832 |
+
"""
|
1833 |
+
return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0
|
1834 |
+
|
1835 |
+
@property
|
1836 |
+
def map(self):
|
1837 |
+
"""Mean AP@0.5:0.95 of all classes.
|
1838 |
+
Return:
|
1839 |
+
float.
|
1840 |
+
"""
|
1841 |
+
return self.all_ap.mean() if len(self.all_ap) else 0.0
|
1842 |
+
|
1843 |
+
def mean_results(self):
|
1844 |
+
"""Mean of results, return mp, mr, map50, map"""
|
1845 |
+
return [self.mp, self.mr, self.map50, self.map]
|
1846 |
+
|
1847 |
+
def class_result(self, i):
|
1848 |
+
"""class-aware result, return p[i], r[i], ap50[i], ap[i]"""
|
1849 |
+
return self.p[i], self.r[i], self.ap50[i], self.ap[i]
|
1850 |
+
|
1851 |
+
@property
|
1852 |
+
def maps(self):
|
1853 |
+
"""mAP of each class"""
|
1854 |
+
maps = np.zeros(self.nc) + self.map
|
1855 |
+
for i, c in enumerate(self.ap_class_index):
|
1856 |
+
maps[c] = self.ap[i]
|
1857 |
+
return maps
|
1858 |
+
|
1859 |
+
def fitness(self):
|
1860 |
+
# Model fitness as a weighted combination of metrics
|
1861 |
+
w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
|
1862 |
+
return (np.array(self.mean_results()) * w).sum()
|
1863 |
+
|
1864 |
+
def update(self, results):
|
1865 |
+
"""
|
1866 |
+
Args:
|
1867 |
+
results: tuple(p, r, ap, f1, ap_class)
|
1868 |
+
"""
|
1869 |
+
self.p, self.r, self.f1, self.all_ap, self.ap_class_index = results
|
1870 |
+
|
1871 |
+
|
1872 |
+
class DetMetrics:
|
1873 |
+
|
1874 |
+
def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
|
1875 |
+
self.save_dir = save_dir
|
1876 |
+
self.plot = plot
|
1877 |
+
self.names = names
|
1878 |
+
self.box = Metric()
|
1879 |
+
|
1880 |
+
def process(self, tp, conf, pred_cls, target_cls):
|
1881 |
+
results = ap_per_class(tp, conf, pred_cls, target_cls, plot=self.plot, save_dir=self.save_dir,
|
1882 |
+
names=self.names)[2:]
|
1883 |
+
self.box.nc = len(self.names)
|
1884 |
+
self.box.update(results)
|
1885 |
+
|
1886 |
+
@property
|
1887 |
+
def keys(self):
|
1888 |
+
return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
|
1889 |
+
|
1890 |
+
def mean_results(self):
|
1891 |
+
return self.box.mean_results()
|
1892 |
+
|
1893 |
+
def class_result(self, i):
|
1894 |
+
return self.box.class_result(i)
|
1895 |
+
|
1896 |
+
@property
|
1897 |
+
def maps(self):
|
1898 |
+
return self.box.maps
|
1899 |
+
|
1900 |
+
@property
|
1901 |
+
def fitness(self):
|
1902 |
+
return self.box.fitness()
|
1903 |
+
|
1904 |
+
@property
|
1905 |
+
def ap_class_index(self):
|
1906 |
+
return self.box.ap_class_index
|
1907 |
+
|
1908 |
+
@property
|
1909 |
+
def results_dict(self):
|
1910 |
+
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
|
1911 |
+
|
1912 |
+
|
1913 |
+
def increment_path(path, exist_ok=False, sep='', mkdir=False):
|
1914 |
+
"""
|
1915 |
+
Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
|
1916 |
+
|
1917 |
+
If the path exists and exist_ok is not set to True, the path will be incremented by appending a number and sep to
|
1918 |
+
the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the
|
1919 |
+
number will be appended directly to the end of the path. If mkdir is set to True, the path will be created as a
|
1920 |
+
directory if it does not already exist.
|
1921 |
+
|
1922 |
+
Args:
|
1923 |
+
path (str or pathlib.Path): Path to increment.
|
1924 |
+
exist_ok (bool, optional): If True, the path will not be incremented and will be returned as-is. Defaults to False.
|
1925 |
+
sep (str, optional): Separator to use between the path and the incrementation number. Defaults to an empty string.
|
1926 |
+
mkdir (bool, optional): If True, the path will be created as a directory if it does not exist. Defaults to False.
|
1927 |
+
|
1928 |
+
Returns:
|
1929 |
+
pathlib.Path: Incremented path.
|
1930 |
+
"""
|
1931 |
+
path = Path(path) # os-agnostic
|
1932 |
+
if path.exists() and not exist_ok:
|
1933 |
+
path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
|
1934 |
+
|
1935 |
+
# Method 1
|
1936 |
+
for n in range(2, 9999):
|
1937 |
+
p = f'{path}{sep}{n}{suffix}' # increment path
|
1938 |
+
if not os.path.exists(p): #
|
1939 |
+
break
|
1940 |
+
path = Path(p)
|
1941 |
+
|
1942 |
+
if mkdir:
|
1943 |
+
path.mkdir(parents=True, exist_ok=True) # make directory
|
1944 |
+
|
1945 |
+
return path
|
1946 |
+
|
1947 |
+
|
1948 |
+
def cfg2dict(cfg):
|
1949 |
+
"""
|
1950 |
+
Convert a configuration object to a dictionary.
|
1951 |
+
|
1952 |
+
This function converts a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
|
1953 |
+
|
1954 |
+
Inputs:
|
1955 |
+
cfg (str) or (Path) or (SimpleNamespace): Configuration object to be converted to a dictionary.
|
1956 |
+
|
1957 |
+
Returns:
|
1958 |
+
cfg (dict): Configuration object in dictionary format.
|
1959 |
+
"""
|
1960 |
+
if isinstance(cfg, (str, Path)):
|
1961 |
+
cfg = yaml_load(cfg) # load dict
|
1962 |
+
elif isinstance(cfg, SimpleNamespace):
|
1963 |
+
cfg = vars(cfg) # convert to dict
|
1964 |
+
return cfg
|
1965 |
+
|
1966 |
+
|
1967 |
+
def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = None, overrides: Dict = None):
|
1968 |
+
"""
|
1969 |
+
Load and merge configuration data from a file or dictionary.
|
1970 |
+
|
1971 |
+
Args:
|
1972 |
+
cfg (str) or (Path) or (Dict) or (SimpleNamespace): Configuration data.
|
1973 |
+
overrides (str) or (Dict), optional: Overrides in the form of a file name or a dictionary. Default is None.
|
1974 |
+
|
1975 |
+
Returns:
|
1976 |
+
(SimpleNamespace): Training arguments namespace.
|
1977 |
+
"""
|
1978 |
+
cfg = cfg2dict(cfg)
|
1979 |
+
|
1980 |
+
# Merge overrides
|
1981 |
+
if overrides:
|
1982 |
+
overrides = cfg2dict(overrides)
|
1983 |
+
cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)
|
1984 |
+
|
1985 |
+
# Special handling for numeric project/names
|
1986 |
+
for k in 'project', 'name':
|
1987 |
+
if k in cfg and isinstance(cfg[k], (int, float)):
|
1988 |
+
cfg[k] = str(cfg[k])
|
1989 |
+
|
1990 |
+
# Type and Value checks
|
1991 |
+
for k, v in cfg.items():
|
1992 |
+
if v is not None: # None values may be from optional args
|
1993 |
+
if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
|
1994 |
+
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
1995 |
+
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
|
1996 |
+
elif k in CFG_FRACTION_KEYS:
|
1997 |
+
if not isinstance(v, (int, float)):
|
1998 |
+
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
1999 |
+
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
|
2000 |
+
if not (0.0 <= v <= 1.0):
|
2001 |
+
raise ValueError(f"'{k}={v}' is an invalid value. "
|
2002 |
+
f"Valid '{k}' values are between 0.0 and 1.0.")
|
2003 |
+
elif k in CFG_INT_KEYS and not isinstance(v, int):
|
2004 |
+
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
2005 |
+
f"'{k}' must be an int (i.e. '{k}=0')")
|
2006 |
+
elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
|
2007 |
+
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
2008 |
+
f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')")
|
2009 |
+
|
2010 |
+
# Return instance
|
2011 |
+
return IterableSimpleNamespace(**cfg)
|
2012 |
+
|
2013 |
+
|
2014 |
+
def clip_boxes(boxes, shape):
|
2015 |
+
"""
|
2016 |
+
It takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the
|
2017 |
+
shape
|
2018 |
+
|
2019 |
+
Args:
|
2020 |
+
boxes (torch.Tensor): the bounding boxes to clip
|
2021 |
+
shape (tuple): the shape of the image
|
2022 |
+
"""
|
2023 |
+
if isinstance(boxes, torch.Tensor): # faster individually
|
2024 |
+
boxes[..., 0].clamp_(0, shape[1]) # x1
|
2025 |
+
boxes[..., 1].clamp_(0, shape[0]) # y1
|
2026 |
+
boxes[..., 2].clamp_(0, shape[1]) # x2
|
2027 |
+
boxes[..., 3].clamp_(0, shape[0]) # y2
|
2028 |
+
else: # np.array (faster grouped)
|
2029 |
+
boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2
|
2030 |
+
boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2
|
2031 |
+
|
2032 |
+
|
2033 |
+
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
|
2034 |
+
"""
|
2035 |
+
Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in
|
2036 |
+
(img1_shape) to the shape of a different image (img0_shape).
|
2037 |
+
|
2038 |
+
Args:
|
2039 |
+
img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
|
2040 |
+
boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
|
2041 |
+
img0_shape (tuple): the shape of the target image, in the format of (height, width).
|
2042 |
+
ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
|
2043 |
+
calculated based on the size difference between the two images.
|
2044 |
+
|
2045 |
+
Returns:
|
2046 |
+
boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
|
2047 |
+
"""
|
2048 |
+
if ratio_pad is None: # calculate from img0_shape
|
2049 |
+
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
|
2050 |
+
pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
|
2051 |
+
else:
|
2052 |
+
gain = ratio_pad[0][0]
|
2053 |
+
pad = ratio_pad[1]
|
2054 |
+
|
2055 |
+
boxes[..., [0, 2]] -= pad[0] # x padding
|
2056 |
+
boxes[..., [1, 3]] -= pad[1] # y padding
|
2057 |
+
boxes[..., :4] /= gain
|
2058 |
+
clip_boxes(boxes, img0_shape)
|
2059 |
+
return boxes
|
2060 |
+
|
2061 |
+
|
2062 |
+
def exif_size(img):
|
2063 |
+
# Returns exif-corrected PIL size
|
2064 |
+
s = img.size # (width, height)
|
2065 |
+
with contextlib.suppress(Exception):
|
2066 |
+
rotation = dict(img._getexif().items())[orientation]
|
2067 |
+
if rotation in [6, 8]: # rotation 270 or 90
|
2068 |
+
s = (s[1], s[0])
|
2069 |
+
return s
|
2070 |
+
|
2071 |
+
|
2072 |
+
def verify_image_label(args):
|
2073 |
+
# Verify one image-label pair
|
2074 |
+
im_file, lb_file, prefix, keypoint, num_cls = args
|
2075 |
+
# number (missing, found, empty, corrupt), message, segments, keypoints
|
2076 |
+
nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
|
2077 |
+
try:
|
2078 |
+
# verify images
|
2079 |
+
im = Image.open(im_file)
|
2080 |
+
im.verify() # PIL verify
|
2081 |
+
shape = exif_size(im) # image size
|
2082 |
+
shape = (shape[1], shape[0]) # hw
|
2083 |
+
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
|
2084 |
+
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
|
2085 |
+
if im.format.lower() in ("jpg", "jpeg"):
|
2086 |
+
with open(im_file, "rb") as f:
|
2087 |
+
f.seek(-2, 2)
|
2088 |
+
|
2089 |
+
# verify labels
|
2090 |
+
if os.path.isfile(lb_file):
|
2091 |
+
nf = 1 # label found
|
2092 |
+
with open(lb_file) as f:
|
2093 |
+
lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
|
2094 |
+
if any(len(x) > 6 for x in lb) and (not keypoint): # is segment
|
2095 |
+
classes = np.array([x[0] for x in lb], dtype=np.float32)
|
2096 |
+
segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...)
|
2097 |
+
lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
|
2098 |
+
lb = np.array(lb, dtype=np.float32)
|
2099 |
+
nl = len(lb)
|
2100 |
+
if nl:
|
2101 |
+
if keypoint:
|
2102 |
+
assert lb.shape[1] == 56, "labels require 56 columns each"
|
2103 |
+
assert (lb[:, 5::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
|
2104 |
+
assert (lb[:, 6::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
|
2105 |
+
kpts = np.zeros((lb.shape[0], 39))
|
2106 |
+
for i in range(len(lb)):
|
2107 |
+
kpt = np.delete(lb[i, 5:], np.arange(2, lb.shape[1] - 5, 3)) # remove occlusion param from GT
|
2108 |
+
kpts[i] = np.hstack((lb[i, :5], kpt))
|
2109 |
+
lb = kpts
|
2110 |
+
assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion parameter"
|
2111 |
+
else:
|
2112 |
+
assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
|
2113 |
+
assert (lb[:, 1:] <= 1).all(), \
|
2114 |
+
f"non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}"
|
2115 |
+
# All labels
|
2116 |
+
max_cls = int(lb[:, 0].max()) # max label count
|
2117 |
+
assert max_cls <= num_cls, \
|
2118 |
+
f'Label class {max_cls} exceeds dataset class count {num_cls}. ' \
|
2119 |
+
f'Possible class labels are 0-{num_cls - 1}'
|
2120 |
+
assert (lb >= 0).all(), f"negative label values {lb[lb < 0]}"
|
2121 |
+
_, i = np.unique(lb, axis=0, return_index=True)
|
2122 |
+
if len(i) < nl: # duplicate row check
|
2123 |
+
lb = lb[i] # remove duplicates
|
2124 |
+
if segments:
|
2125 |
+
segments = [segments[x] for x in i]
|
2126 |
+
msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
|
2127 |
+
else:
|
2128 |
+
ne = 1 # label empty
|
2129 |
+
lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
|
2130 |
+
else:
|
2131 |
+
nm = 1 # label missing
|
2132 |
+
lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
|
2133 |
+
if keypoint:
|
2134 |
+
keypoints = lb[:, 5:].reshape(-1, 17, 2)
|
2135 |
+
lb = lb[:, :5]
|
2136 |
+
return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
|
2137 |
+
except Exception as e:
|
2138 |
+
nc = 1
|
2139 |
+
msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
|
2140 |
+
return [None, None, None, None, None, nm, nf, ne, nc, msg]
|
yolov8m_qat.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3b770e88b358ad24cc60e7b8bbc00b09bb1e0308f65f45cdcea2a1dfc1301077
|
3 |
+
size 103874610
|