zhengrongzhang committed on
Commit
e6c79f4
1 Parent(s): ae727cb

init model

Files changed (12)
  1. README.md +133 -0
  2. anchors.npy +3 -0
  3. coco.yaml +97 -0
  4. default.yaml +119 -0
  5. demo.jpg +0 -0
  6. general_json2yolo.py +183 -0
  7. onnx_eval.py +284 -0
  8. onnx_inference.py +148 -0
  9. requirements.txt +43 -0
  10. strides.npy +3 -0
  11. utils.py +2140 -0
  12. yolov8m_qat.onnx +3 -0
README.md ADDED
@@ -0,0 +1,133 @@
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - RyzenAI
5
+ - object-detection
6
+ - vision
7
+ - YOLO
8
+ - Pytorch
9
+ datasets:
10
+ - COCO
11
+ metrics:
12
+ - mAP
13
+ ---
14
+ # YOLOv8m model trained on COCO
15
+
16
+ YOLOv8m is the medium variant of the YOLOv8 model, trained for COCO object detection (118k annotated images) at a resolution of 640x640. It was released in [https://github.com/ultralytics/ultralytics](https://github.com/ultralytics/ultralytics).
17
+
18
+ We provide a modified version that is supported by [AMD Ryzen AI](https://onnxruntime.ai/docs/execution-providers/Vitis-AI-ExecutionProvider.html).
19
+
20
+
21
+ ## Model description
22
+
23
+ Ultralytics YOLOv8 is a cutting-edge, state-of-the-art (SOTA) model that builds upon the success of previous YOLO versions and introduces new features and improvements to further boost performance and flexibility. YOLOv8 is designed to be fast, accurate, and easy to use, making it an excellent choice for a wide range of object detection and tracking, instance segmentation, image classification and pose estimation tasks.
24
+
25
+
26
+ ## Intended uses & limitations
27
+
28
+ You can use the raw model for object detection. See the [model hub](https://huggingface.co/models?search=amd/yolov8) to find all available YOLOv8 models.
29
+
30
+
31
+ ## How to use
32
+
33
+ ### Installation
34
+
35
+ Follow [Ryzen AI Installation](https://ryzenai.docs.amd.com/en/latest/inst.html) to prepare the environment for Ryzen AI.
36
+ Then run the following command to install the prerequisites for this model.
37
+ ```bash
38
+ pip install -r requirements.txt
39
+ ```
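+
+ Optionally, you can check that ONNX Runtime and the Vitis AI execution provider are visible in your environment (a quick sanity check, assuming the Ryzen AI setup completed successfully):
+ ```python
+ import onnxruntime
+
+ # "VitisAIExecutionProvider" should appear in this list once the Ryzen AI setup is active
+ print(onnxruntime.get_available_providers())
+ ```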
40
+
41
+
42
+ ### Data Preparation (optional: for accuracy evaluation)
43
+
44
+ The MS COCO 2017 dataset contains 118,287 training images and 5,000 validation images.
45
+
46
+ Download the COCO dataset and arrange the directories as follows:
47
+ ```plain
+ └── datasets
+     └── coco
+         ├── annotations
+         |   ├── instances_val2017.json
+         |   └── ...
+         ├── labels
+         |   └── val2017
+         |       ├── 000000000139.txt
+         |       ├── 000000000285.txt
+         |       └── ...
+         ├── images
+         |   └── val2017
+         |       ├── 000000000139.jpg
+         |       └── 000000000285.jpg
+         └── val2017.txt
+ ```
64
+ 1. Put the val2017 image folder under the images directory, or use a symlink.
65
+ 2. The labels folder and val2017.txt above are generated by **general_json2yolo.py** (see the sketch after this list).
66
+ 3. Modify coco.yaml like this:
67
+ ```yaml
68
+ path: /path/to/your/datasets/coco # dataset root dir
69
+ train: train2017.txt # train images (relative to 'path') 118287 images
70
+ val: val2017.txt # val images (relative to 'path') 5000 images
71
+ ```
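+
+ A minimal sketch of generating the labels, assuming the COCO annotations are already under ./datasets/coco/annotations (these are the default paths hard-coded in general_json2yolo.py):
+ ```bash
+ # Writes ./datasets/coco/labels/val2017/*.txt and ./datasets/coco/val2017.txt
+ # from instances_val2017.json
+ python general_json2yolo.py
+ ```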
72
+
73
+
74
+ ### Test & Evaluation
75
+
76
+ - Code snippet from [`onnx_inference.py`](onnx_inference.py) showing how to run the model:
77
+ ```python
+ args = make_parser().parse_args()
+ source = args.image_path
+ dataset = LoadImages(
+     source, imgsz=imgsz, stride=32, auto=False, transforms=None, vid_stride=1
+ )
+ onnx_weight = args.model
+ onnx_model = onnxruntime.InferenceSession(onnx_weight)
+ for batch in dataset:
+     path, im, im0s, vid_cap, s = batch
+     im = preprocess(im)
+     if len(im.shape) == 3:
+         im = im[None]
+     outputs = onnx_model.run(None, {onnx_model.get_inputs()[0].name: im.cpu().numpy()})
+     outputs = [torch.tensor(item) for item in outputs]
+     preds = post_process(outputs)
+     preds = non_max_suppression(
+         preds, 0.25, 0.7, agnostic=False, max_det=300, classes=None
+     )
+     plot_images(
+         im,
+         *output_to_target(preds, max_det=15),
+         source,
+         fname=args.output_path,
+         names=names,
+     )
+ ```
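+
+ For Ryzen AI, the session is instead created with the Vitis AI execution provider, as in the `--ipu` branch of [`onnx_inference.py`](onnx_inference.py):
+ ```python
+ providers = ["VitisAIExecutionProvider"]
+ provider_options = [{"config_file": args.provider_config}]  # path to vaip_config.json
+ onnx_model = onnxruntime.InferenceSession(
+     onnx_weight, providers=providers, provider_options=provider_options
+ )
+ ```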
105
+
106
+ - Run inference for a single image
107
+ ```bash
108
+ python onnx_inference.py -m ./yolov8m_qat.onnx -i /Path/To/Your/Image --ipu --provider_config /Path/To/Your/Provider_config
109
+ ```
110
+ *Note: __vaip_config.json__ is located in the Ryzen AI setup package (refer to [Installation](#installation))*
111
+ - Test accuracy of the quantized model
112
+ ```bash
113
+ python onnx_eval.py -m ./yolov8m_qat.onnx --ipu --provider_config /Path/To/Your/Provider_config
114
+ ```
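+
+ Without Ryzen AI hardware, both scripts can also be run by omitting the `--ipu` and `--provider_config` flags; the model then falls back to ONNX Runtime's default CPU execution provider (see the `else` branch in the scripts). For example:
+ ```bash
+ # CPU-only accuracy evaluation of the quantized model
+ python onnx_eval.py -m ./yolov8m_qat.onnx
+ ```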
115
+
116
+ ### Performance
117
+
118
+ |Metric |Accuracy on IPU|
119
+ | :----: | :----: |
120
+ |AP\@0.50:0.95|0.486|
121
+
122
+
123
+ ```bibtex
124
+ @software{yolov8_ultralytics,
125
+ author = {Glenn Jocher and Ayush Chaurasia and Jing Qiu},
126
+ title = {Ultralytics YOLOv8},
127
+ version = {8.0.0},
128
+ year = {2023},
129
+ url = {https://github.com/ultralytics/ultralytics},
130
+ orcid = {0000-0001-5950-6979, 0000-0002-7603-6750, 0000-0003-3783-7069},
131
+ license = {AGPL-3.0}
132
+ }
133
+ ```
anchors.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b677c029271bca12b3e4311e468d316962e31d1a2b7bae3eed90c9c65738fa31
3
+ size 67328
coco.yaml ADDED
@@ -0,0 +1,97 @@
1
+ # Ultralytics YOLO 🚀, GPL-3.0 license
2
+ # COCO 2017 dataset http://cocodataset.org by Microsoft
3
+ # Example usage: python train.py --data coco.yaml
4
+ # parent
5
+ # ├── yolov5
6
+ # └── datasets
7
+ # └── coco ← downloads here (20.1 GB)
8
+
9
+
10
+ # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11
+ path: ./datasets/coco # dataset root dir
12
+ train: train2017.txt # train images (relative to 'path') 118287 images
13
+ val: val2017.txt # val images (relative to 'path') 5000 images
14
+ #test: val2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794
15
+
16
+ # Classes
17
+ names:
18
+ 0: person
19
+ 1: bicycle
20
+ 2: car
21
+ 3: motorcycle
22
+ 4: airplane
23
+ 5: bus
24
+ 6: train
25
+ 7: truck
26
+ 8: boat
27
+ 9: traffic light
28
+ 10: fire hydrant
29
+ 11: stop sign
30
+ 12: parking meter
31
+ 13: bench
32
+ 14: bird
33
+ 15: cat
34
+ 16: dog
35
+ 17: horse
36
+ 18: sheep
37
+ 19: cow
38
+ 20: elephant
39
+ 21: bear
40
+ 22: zebra
41
+ 23: giraffe
42
+ 24: backpack
43
+ 25: umbrella
44
+ 26: handbag
45
+ 27: tie
46
+ 28: suitcase
47
+ 29: frisbee
48
+ 30: skis
49
+ 31: snowboard
50
+ 32: sports ball
51
+ 33: kite
52
+ 34: baseball bat
53
+ 35: baseball glove
54
+ 36: skateboard
55
+ 37: surfboard
56
+ 38: tennis racket
57
+ 39: bottle
58
+ 40: wine glass
59
+ 41: cup
60
+ 42: fork
61
+ 43: knife
62
+ 44: spoon
63
+ 45: bowl
64
+ 46: banana
65
+ 47: apple
66
+ 48: sandwich
67
+ 49: orange
68
+ 50: broccoli
69
+ 51: carrot
70
+ 52: hot dog
71
+ 53: pizza
72
+ 54: donut
73
+ 55: cake
74
+ 56: chair
75
+ 57: couch
76
+ 58: potted plant
77
+ 59: bed
78
+ 60: dining table
79
+ 61: toilet
80
+ 62: tv
81
+ 63: laptop
82
+ 64: mouse
83
+ 65: remote
84
+ 66: keyboard
85
+ 67: cell phone
86
+ 68: microwave
87
+ 69: oven
88
+ 70: toaster
89
+ 71: sink
90
+ 72: refrigerator
91
+ 73: book
92
+ 74: clock
93
+ 75: vase
94
+ 76: scissors
95
+ 77: teddy bear
96
+ 78: hair drier
97
+ 79: toothbrush
default.yaml ADDED
@@ -0,0 +1,119 @@
1
+ # Ultralytics YOLO 🚀, GPL-3.0 license
2
+ # Default training settings and hyperparameters for medium-augmentation COCO training
3
+
4
+ task: detect # inference task, i.e. detect, segment, classify
5
+ mode: train # YOLO mode, i.e. train, val, predict, export
6
+
7
+ # Train settings -------------------------------------------------------------------------------------------------------
8
+ model: # path to model file, i.e. yolov8n.pt, yolov8n.yaml
9
+ data: "./coco.yaml" # path to data file, i.e. coco128.yaml
10
+ epochs: 100 # number of epochs to train for
11
+ patience: 50 # epochs to wait for no observable improvement for early stopping of training
12
+ batch: 1 # number of images per batch (-1 for AutoBatch)
13
+ imgsz: 640 # size of input images as integer or w,h
14
+ save: True # save train checkpoints and predict results
15
+ cache: False # True/ram, disk or False. Use cache for data loading
16
+ device: # device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu
17
+ workers: 8 # number of worker threads for data loading (per RANK if DDP)
18
+ project: # project name
19
+ name: # experiment name
20
+ exist_ok: False # whether to overwrite existing experiment
21
+ pretrained: False # whether to use a pretrained model
22
+ optimizer: SGD # optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
23
+ verbose: True # whether to print verbose output
24
+ seed: 0 # random seed for reproducibility
25
+ deterministic: True # whether to enable deterministic mode
26
+ single_cls: False # train multi-class data as single-class
27
+ image_weights: False # use weighted image selection for training
28
+ rect: False # support rectangular training if mode='train', support rectangular evaluation if mode='val'
29
+ cos_lr: False # use cosine learning rate scheduler
30
+ close_mosaic: 10 # disable mosaic augmentation for final 10 epochs
31
+ resume: False # resume training from last checkpoint
32
+ min_memory: False # minimize memory footprint loss function, choices=[False, True, <roll_out_thr>]
33
+ sync_bn: False # convert batchnorm to syncbatchnorm in model
34
+ nndct_quant: False # True for quant model
35
+ quant_mode: 'test' # calib or test
36
+ dump_xmodel: False # True for dump xmodel
37
+ dump_onnx: False # True for dump onnx
38
+ onnx_weight: "./yolov8m_qat.onnx"
39
+ onnx_runtime: False
40
+ ipu: False
41
+ provider_config: ''
42
+ # Segmentation
43
+ overlap_mask: True # masks should overlap during training (segment train only)
44
+ mask_ratio: 4 # mask downsample ratio (segment train only)
45
+ # Classification
46
+ dropout: 0.0 # use dropout regularization (classify train only)
47
+
48
+ # Val/Test settings ----------------------------------------------------------------------------------------------------
49
+ val: True # validate/test during training
50
+ save_json: False # save results to JSON file
51
+ save_hybrid: False # save hybrid version of labels (labels + additional predictions)
52
+ conf: # object confidence threshold for detection (default 0.25 predict, 0.001 val)
53
+ iou: 0.7 # intersection over union (IoU) threshold for NMS
54
+ max_det: 300 # maximum number of detections per image
55
+ half: False # use half precision (FP16)
56
+ dnn: False # use OpenCV DNN for ONNX inference
57
+ plots: True # save plots during train/val
58
+
59
+ # Prediction settings --------------------------------------------------------------------------------------------------
60
+ source: # source directory for images or videos
61
+ show: True # show results if possible
62
+ save_txt: True # save results as .txt file
63
+ save_conf: False # save results with confidence scores
64
+ save_crop: False # save cropped images with results
65
+ hide_labels: False # hide labels
66
+ hide_conf: False # hide confidence scores
67
+ vid_stride: 1 # video frame-rate stride
68
+ line_thickness: 3 # bounding box thickness (pixels)
69
+ visualize: False # visualize model features
70
+ augment: False # apply image augmentation to prediction sources
71
+ agnostic_nms: False # class-agnostic NMS
72
+ classes: # filter results by class, i.e. class=0, or class=[0,2,3]
73
+ retina_masks: False # use high-resolution segmentation masks
74
+ boxes: True # Show boxes in segmentation predictions
75
+
76
+ # Export settings ------------------------------------------------------------------------------------------------------
77
+ format: torchscript # format to export to
78
+ keras: False # use Keras
79
+ optimize: False # TorchScript: optimize for mobile
80
+ int8: False # CoreML/TF INT8 quantization
81
+ dynamic: False # ONNX/TF/TensorRT: dynamic axes
82
+ simplify: False # ONNX: simplify model
83
+ opset: # ONNX: opset version (optional)
84
+ workspace: 4 # TensorRT: workspace size (GB)
85
+ nms: False # CoreML: add NMS
86
+
87
+ # Hyperparameters ------------------------------------------------------------------------------------------------------
88
+ lr0: 0.01 # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
89
+ lrf: 0.01 # final learning rate (lr0 * lrf)
90
+ momentum: 0.937 # SGD momentum/Adam beta1
91
+ weight_decay: 0.0005 # optimizer weight decay 5e-4
92
+ warmup_epochs: 3.0 # warmup epochs (fractions ok)
93
+ warmup_momentum: 0.8 # warmup initial momentum
94
+ warmup_bias_lr: 0.1 # warmup initial bias lr
95
+ box: 7.5 # box loss gain
96
+ cls: 0.5 # cls loss gain (scale with pixels)
97
+ dfl: 1.5 # dfl loss gain
98
+ fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
99
+ label_smoothing: 0.0 # label smoothing (fraction)
100
+ nbs: 64 # nominal batch size
101
+ hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
102
+ hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
103
+ hsv_v: 0.4 # image HSV-Value augmentation (fraction)
104
+ degrees: 0.0 # image rotation (+/- deg)
105
+ translate: 0.1 # image translation (+/- fraction)
106
+ scale: 0.5 # image scale (+/- gain)
107
+ shear: 0.0 # image shear (+/- deg)
108
+ perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
109
+ flipud: 0.0 # image flip up-down (probability)
110
+ fliplr: 0.5 # image flip left-right (probability)
111
+ mosaic: 1.0 # image mosaic (probability)
112
+ mixup: 0.0 # image mixup (probability)
113
+ copy_paste: 0.0 # segment copy-paste (probability)
114
+
115
+ # Custom config.yaml ---------------------------------------------------------------------------------------------------
116
+ cfg: # for overriding defaults.yaml
117
+
118
+ # Debug, do not modify -------------------------------------------------------------------------------------------------
119
+ v5loader: False # use legacy YOLOv5 dataloader
demo.jpg ADDED
general_json2yolo.py ADDED
@@ -0,0 +1,183 @@
1
+ import json
2
+ from collections import defaultdict
3
+ from pathlib import Path
4
+ from tqdm import tqdm
5
+ import numpy as np
6
+ import sys
7
+ import pathlib
8
+
9
+ CURRENT_DIR = pathlib.Path(__file__).parent
10
+ sys.path.append(str(CURRENT_DIR))
11
+
12
+
13
+ def make_dirs(dir="./datasets/coco"):
14
+ # Create folders
15
+ dir = Path(dir)
16
+ for p in [dir / "labels"]:
17
+ p.mkdir(parents=True, exist_ok=True) # make dir
18
+ return dir
19
+
20
+
21
+ def coco91_to_coco80_class(): # converts 80-index (val2014) to 91-index (paper)
22
+ # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
23
+ x = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, None, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, None, 24, 25, None,
24
+ None, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, None, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
25
+ 51, 52, 53, 54, 55, 56, 57, 58, 59, None, 60, None, None, 61, None, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72,
26
+ None, 73, 74, 75, 76, 77, 78, 79, None]
27
+ return x
28
+
29
+
30
+ def convert_coco_json(
31
+ json_dir="../coco/annotations/", use_segments=False, cls91to80=False
32
+ ):
33
+ save_dir = make_dirs() # output directory
34
+ coco80 = coco91_to_coco80_class()
35
+
36
+ # Import json
37
+ for json_file in sorted(Path(json_dir).resolve().glob("*.json")):
38
+ if not str(json_file).endswith("instances_val2017.json"):
39
+ continue
40
+ fn = (
41
+ Path(save_dir) / "labels" / json_file.stem.replace("instances_", "")
42
+ ) # folder name
43
+ fn.mkdir()
44
+ with open(json_file) as f:
45
+ data = json.load(f)
46
+
47
+ # Create image dict
48
+ images = {"%g" % x["id"]: x for x in data["images"]}
49
+ # Create image-annotations dict
50
+ imgToAnns = defaultdict(list)
51
+ for ann in data["annotations"]:
52
+ imgToAnns[ann["image_id"]].append(ann)
53
+
54
+ txt_file = open(Path(save_dir / "val2017").with_suffix(".txt"), "a")
55
+ # Write labels file
56
+ for img_id, anns in tqdm(imgToAnns.items(), desc=f"Annotations {json_file}"):
57
+ img = images["%g" % img_id]
58
+ h, w, f = img["height"], img["width"], img["file_name"]
59
+ bboxes = []
60
+ segments = []
61
+
62
+ txt_file.write(
63
+ "./images/" + "/".join(img["coco_url"].split("/")[-2:]) + "\n"
64
+ )
65
+ for ann in anns:
66
+ if ann["iscrowd"]:
67
+ continue
68
+ # The COCO box format is [top left x, top left y, width, height]
69
+ box = np.array(ann["bbox"], dtype=np.float64)
70
+ box[:2] += box[2:] / 2 # xy top-left corner to center
71
+ box[[0, 2]] /= w # normalize x
72
+ box[[1, 3]] /= h # normalize y
73
+ if box[2] <= 0 or box[3] <= 0: # if w <= 0 and h <= 0
74
+ continue
75
+
76
+ cls = (
77
+ coco80[ann["category_id"] - 1]
78
+ if cls91to80
79
+ else ann["category_id"] - 1
80
+ ) # class
81
+ box = [cls] + box.tolist()
82
+ if box not in bboxes:
83
+ bboxes.append(box)
84
+ # Segments
85
+ if use_segments:
86
+ if len(ann["segmentation"]) > 1:
87
+ s = merge_multi_segment(ann["segmentation"])
88
+ s = (
89
+ (np.concatenate(s, axis=0) / np.array([w, h]))
90
+ .reshape(-1)
91
+ .tolist()
92
+ )
93
+ else:
94
+ s = [
95
+ j for i in ann["segmentation"] for j in i
96
+ ] # all segments concatenated
97
+ s = (
98
+ (np.array(s).reshape(-1, 2) / np.array([w, h]))
99
+ .reshape(-1)
100
+ .tolist()
101
+ )
102
+ s = [cls] + s
103
+ if s not in segments:
104
+ segments.append(s)
105
+
106
+ # Write
107
+ with open((fn / f).with_suffix(".txt"), "a") as file:
108
+ for i in range(len(bboxes)):
109
+ line = (
110
+ *(segments[i] if use_segments else bboxes[i]),
111
+ ) # cls, box or segments
112
+ file.write(("%g " * len(line)).rstrip() % line + "\n")
113
+ txt_file.close()
114
+
115
+
116
+ def min_index(arr1, arr2):
117
+ """Find a pair of indexes with the shortest distance.
118
+ Args:
119
+ arr1: (N, 2).
120
+ arr2: (M, 2).
121
+ Return:
122
+ a pair of indexes(tuple).
123
+ """
124
+ dis = ((arr1[:, None, :] - arr2[None, :, :]) ** 2).sum(-1)
125
+ return np.unravel_index(np.argmin(dis, axis=None), dis.shape)
126
+
127
+
128
+ def merge_multi_segment(segments):
129
+ """Merge multi segments to one list.
130
+ Find the coordinates with min distance between each segment,
131
+ then connect these coordinates with one thin line to merge all
132
+ segments into one.
133
+
134
+ Args:
135
+ segments(List(List)): original segmentations in coco's json file.
136
+ like [segmentation1, segmentation2,...],
137
+ each segmentation is a list of coordinates.
138
+ """
139
+ s = []
140
+ segments = [np.array(i).reshape(-1, 2) for i in segments]
141
+ idx_list = [[] for _ in range(len(segments))]
142
+
143
+ # record the indexes with min distance between each segment
144
+ for i in range(1, len(segments)):
145
+ idx1, idx2 = min_index(segments[i - 1], segments[i])
146
+ idx_list[i - 1].append(idx1)
147
+ idx_list[i].append(idx2)
148
+
149
+ # use two round to connect all the segments
150
+ for k in range(2):
151
+ # forward connection
152
+ if k == 0:
153
+ for i, idx in enumerate(idx_list):
154
+ # middle segments have two indexes
155
+ # reverse the index of middle segments
156
+ if len(idx) == 2 and idx[0] > idx[1]:
157
+ idx = idx[::-1]
158
+ segments[i] = segments[i][::-1, :]
159
+
160
+ segments[i] = np.roll(segments[i], -idx[0], axis=0)
161
+ segments[i] = np.concatenate([segments[i], segments[i][:1]])
162
+ # deal with the first segment and the last one
163
+ if i in [0, len(idx_list) - 1]:
164
+ s.append(segments[i])
165
+ else:
166
+ idx = [0, idx[1] - idx[0]]
167
+ s.append(segments[i][idx[0] : idx[1] + 1])
168
+
169
+ else:
170
+ for i in range(len(idx_list) - 1, -1, -1):
171
+ if i not in [0, len(idx_list) - 1]:
172
+ idx = idx_list[i]
173
+ nidx = abs(idx[1] - idx[0])
174
+ s.append(segments[i][nidx:])
175
+ return s
176
+
177
+
178
+ if __name__ == "__main__":
179
+ convert_coco_json(
180
+ "./datasets/coco/annotations", # directory with *.json
181
+ use_segments=True,
182
+ cls91to80=True,
183
+ )
onnx_eval.py ADDED
@@ -0,0 +1,284 @@
1
+ import json
2
+ from pathlib import Path
3
+ import torch
4
+ import argparse
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+ import onnxruntime
8
+ from utils import check_det_dataset, yaml_load, IterableSimpleNamespace, build_dataloader, post_process, xyxy2xywh, LOGGER, \
9
+ DetMetrics, increment_path, get_cfg, smart_inference_mode, box_iou, TQDM_BAR_FORMAT, scale_boxes, non_max_suppression, xywh2xyxy
10
+
11
+ # Default configuration
12
+ DEFAULT_CFG_DICT = yaml_load("./default.yaml")
13
+ for k, v in DEFAULT_CFG_DICT.items():
14
+ if isinstance(v, str) and v.lower() == 'none':
15
+ DEFAULT_CFG_DICT[k] = None
16
+ DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
17
+ DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
18
+ import sys
19
+ import pathlib
20
+ CURRENT_DIR = pathlib.Path(__file__).parent
21
+ sys.path.append(str(CURRENT_DIR))
22
+
23
+
24
+ class DetectionValidator:
25
+
26
+ def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
27
+ self.dataloader = dataloader
28
+ self.pbar = pbar
29
+ self.logger = LOGGER
30
+ self.args = args
31
+ self.model = None
32
+ self.data = None
33
+ self.device = None
34
+ self.batch_i = None
35
+ self.speed = None
36
+ self.jdict = None
37
+ self.args.task = 'detect'
38
+ project = Path("./runs") / self.args.task
39
+ self.save_dir = save_dir or increment_path(Path(project),
40
+ exist_ok=True)
41
+ (self.save_dir / 'labels').mkdir(parents=True, exist_ok=True)
42
+ self.args.conf = 0.001 # default conf=0.001
43
+ self.is_coco = False
44
+ self.class_map = None
45
+ self.metrics = DetMetrics(save_dir=self.save_dir)
46
+ self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95
47
+ self.niou = self.iouv.numel()
48
+
49
+ @smart_inference_mode()
50
+ def __call__(self, trainer=None, model=None):
51
+ """
52
+ Supports validation of a pre-trained model if passed or a model being trained
53
+ if trainer is passed (trainer gets priority).
54
+ """
55
+ self.device = torch.device('cpu')
56
+ onnx_weight = self.args.onnx_weight
57
+ if isinstance(onnx_weight, list):
58
+ onnx_weight = onnx_weight[0]
59
+ if self.args.ipu:
60
+ providers = ["VitisAIExecutionProvider"]
61
+ provider_options = [{"config_file": self.args.provider_config}]
62
+ onnx_model = onnxruntime.InferenceSession(onnx_weight, providers=providers, provider_options=provider_options)
63
+ else:
64
+ onnx_model = onnxruntime.InferenceSession(onnx_weight)
65
+ self.data = check_det_dataset(self.args.data)
66
+ self.args.rect = False
67
+ self.dataloader = self.dataloader or self.get_dataloader(self.data.get("val") or self.data.get("test"), self.args.batch)
68
+ total = len(self.dataloader)
69
+ n_batches = len(self.dataloader)
70
+ desc = self.get_desc()
71
+ bar = tqdm(self.dataloader, desc, total, bar_format=TQDM_BAR_FORMAT)
72
+ self.init_metrics()
73
+ self.jdict = [] # empty before each val
74
+
75
+ for batch_i, batch in enumerate(bar):
76
+ self.batch_i = batch_i
77
+ # pre-process
78
+ batch = self.preprocess(batch)
79
+
80
+ # inference
81
+ outputs = onnx_model.run(None, {onnx_model.get_inputs()[0].name: batch["img"].cpu().numpy()})
82
+ outputs = [torch.tensor(item).to(self.device) for item in outputs]
83
+ preds = post_process(outputs)
84
+
85
+ # pre-process predictions
86
+ preds = self.postprocess(preds)
87
+ self.update_metrics(preds, batch)
88
+ stats = self.get_stats()
89
+ self.print_results()
90
+ if self.args.save_json and self.jdict:
91
+ with open(str(self.save_dir / "predictions.json"), 'w') as f:
92
+ self.logger.info(f"Saving {f.name}...")
93
+ json.dump(self.jdict, f) # flatten and save
94
+ stats = self.eval_json(stats) # update stats
95
+ return stats
96
+
97
+ def get_dataloader(self, dataset_path, batch_size):
98
+ # calculate stride - check if model is initialized
99
+ return build_dataloader(self.args, batch_size, img_path=dataset_path, stride=32, names=self.data['names'], mode="val")[0]
100
+
101
+ def get_desc(self):
102
+ return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)")
103
+
104
+ def init_metrics(self):
105
+ self.is_coco = True
106
+ self.class_map = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
107
+ 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
108
+ 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
109
+ self.args.save_json = True
110
+ self.nc = 80
111
+ classnames = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
112
+ 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
113
+ 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
114
+ 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
115
+ 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
116
+ 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
117
+ 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
118
+ 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
119
+ 'hair drier', 'toothbrush']
120
+ self.names = {k: classnames[k] for k in range(80)}
121
+ self.metrics.names = self.names
122
+ self.metrics.plot = True
123
+ self.seen = 0
124
+ self.jdict = []
125
+ self.stats = []
126
+
127
+ def preprocess(self, batch):
128
+ batch["img"] = batch["img"].to(self.device, non_blocking=True)
129
+ batch["img"] = batch["img"].float() / 255
130
+ for k in ["batch_idx", "cls", "bboxes"]:
131
+ batch[k] = batch[k].to(self.device)
132
+
133
+ nb = len(batch["img"])
134
+ self.lb = [torch.cat([batch["cls"], batch["bboxes"]], dim=-1)[batch["batch_idx"] == i]
135
+ for i in range(nb)] if self.args.save_hybrid else [] # for autolabelling
136
+
137
+ return batch
138
+
139
+ def postprocess(self, preds):
140
+ preds = non_max_suppression(preds,
141
+ self.args.conf,
142
+ self.args.iou,
143
+ labels=self.lb,
144
+ multi_label=True,
145
+ agnostic=self.args.single_cls,
146
+ max_det=self.args.max_det)
147
+ return preds
148
+
149
+ def update_metrics(self, preds, batch):
150
+ # Metrics
151
+ for si, pred in enumerate(preds):
152
+ idx = batch["batch_idx"] == si
153
+ cls = batch["cls"][idx]
154
+ bbox = batch["bboxes"][idx]
155
+ nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
156
+ shape = batch["ori_shape"][si]
157
+ correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
158
+ self.seen += 1
159
+
160
+ if npr == 0:
161
+ if nl:
162
+ self.stats.append((correct_bboxes, *torch.zeros((2, 0), device=self.device), cls.squeeze(-1)))
163
+ continue
164
+
165
+ # Predictions
166
+ if self.args.single_cls:
167
+ pred[:, 5] = 0
168
+ predn = pred.clone()
169
+ scale_boxes(batch["img"][si].shape[1:], predn[:, :4], shape,
170
+ ratio_pad=batch["ratio_pad"][si]) # native-space pred
171
+
172
+ # Evaluate
173
+ if nl:
174
+ height, width = batch["img"].shape[2:]
175
+ tbox = xywh2xyxy(bbox) * torch.tensor(
176
+ (width, height, width, height), device=self.device) # target boxes
177
+ scale_boxes(batch["img"][si].shape[1:], tbox, shape,
178
+ ratio_pad=batch["ratio_pad"][si]) # native-space labels
179
+ labelsn = torch.cat((cls, tbox), 1) # native-space labels
180
+ correct_bboxes = self._process_batch(predn, labelsn)
181
+ self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], cls.squeeze(-1))) # (conf, pcls, tcls)
182
+
183
+ # Save
184
+ if self.args.save_json:
185
+ self.pred_to_json(predn, batch["im_file"][si])
186
+
187
+ def _process_batch(self, detections, labels):
188
+ """
189
+ Return correct prediction matrix
190
+ Arguments:
191
+ detections (array[N, 6]), x1, y1, x2, y2, conf, class
192
+ labels (array[M, 5]), class, x1, y1, x2, y2
193
+ Returns:
194
+ correct (array[N, 10]), for 10 IoU levels
195
+ """
196
+ iou = box_iou(labels[:, 1:], detections[:, :4])
197
+ correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool)
198
+ correct_class = labels[:, 0:1] == detections[:, 5]
199
+ for i in range(len(self.iouv)):
200
+ x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match
201
+ if x[0].shape[0]:
202
+ matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
203
+ 1).cpu().numpy() # [label, detect, iou]
204
+ if x[0].shape[0] > 1:
205
+ matches = matches[matches[:, 2].argsort()[::-1]]
206
+ matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
207
+ # matches = matches[matches[:, 2].argsort()[::-1]]
208
+ matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
209
+ correct[matches[:, 1].astype(int), i] = True
210
+ return torch.tensor(correct, dtype=torch.bool, device=detections.device)
211
+
212
+ def pred_to_json(self, predn, filename):
213
+ stem = Path(filename).stem
214
+ image_id = int(stem) if stem.isnumeric() else stem
215
+ box = xyxy2xywh(predn[:, :4]) # xywh
216
+ box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
217
+ for p, b in zip(predn.tolist(), box.tolist()):
218
+ self.jdict.append({
219
+ 'image_id': image_id,
220
+ 'category_id': self.class_map[int(p[5])],
221
+ 'bbox': [round(x, 3) for x in b],
222
+ 'score': round(p[4], 5)})
223
+
224
+ def get_stats(self):
225
+ stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
226
+ if len(stats) and stats[0].any():
227
+ self.metrics.process(*stats)
228
+ self.nt_per_class = np.bincount(stats[-1].astype(int), minlength=self.nc) # number of targets per class
229
+ return self.metrics.results_dict
230
+
231
+ def print_results(self):
232
+ pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys) # print format
233
+ self.logger.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
234
+ if self.nt_per_class.sum() == 0:
235
+ self.logger.warning(
236
+ f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels')
237
+
238
+ # Print results per class
239
+ if self.args.verbose and self.nc > 1 and len(self.stats):
240
+ for i, c in enumerate(self.metrics.ap_class_index):
241
+ self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
242
+
243
+
244
+ def eval_json(self, stats):
245
+ if self.args.save_json and self.is_coco and len(self.jdict):
246
+ anno_json = Path(self.data['path']) / 'annotations/instances_val2017.json' # annotations
247
+ pred_json = self.save_dir / "predictions.json" # predictions
248
+ self.logger.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
249
+ try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
250
+ from pycocotools.coco import COCO # noqa
251
+ from pycocotools.cocoeval import COCOeval # noqa
252
+ # for x in anno_json, pred_json:
253
+ # assert x.is_file(), f"{x} file not found"
254
+ anno = COCO(str(anno_json)) # init annotations api
255
+ pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
256
+ eval = COCOeval(anno, pred, 'bbox')
257
+ if self.is_coco:
258
+ eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
259
+ eval.evaluate()
260
+ eval.accumulate()
261
+ eval.summarize()
262
+ stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50
263
+ except Exception as e:
264
+ self.logger.warning(f'pycocotools unable to run: {e}')
265
+ return stats
266
+
267
+
268
+ def parse_opt():
269
+ parser = argparse.ArgumentParser()
270
+ parser.add_argument('--ipu', action='store_true', help='flag for ryzen ai')
271
+ parser.add_argument('--provider_config', default='', type=str, help='provider config for ryzen ai')
272
+ parser.add_argument("-m", "--model", default="./yolov8m_qat.onnx", type=str, help='onnx_weight')
273
+ opt = parser.parse_args()
274
+ return opt
275
+
276
+
277
+ if __name__ == "__main__":
278
+ opt = parse_opt()
279
+ args = get_cfg(DEFAULT_CFG)
280
+ args.ipu = opt.ipu
281
+ args.onnx_weight = opt.model
282
+ args.provider_config = opt.provider_config
283
+ validator = DetectionValidator(args=args)
284
+ validator()
onnx_inference.py ADDED
@@ -0,0 +1,148 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import onnxruntime
4
+ import numpy as np
5
+ import argparse
6
+ from utils import (
7
+ LoadImages,
8
+ non_max_suppression,
9
+ plot_images,
10
+ output_to_target,
11
+ )
12
+ import sys
13
+ import pathlib
14
+ CURRENT_DIR = pathlib.Path(__file__).parent
15
+ sys.path.append(str(CURRENT_DIR))
16
+
17
+ def preprocess(img):
18
+ img = torch.from_numpy(img)
19
+ img = img.float() # uint8 to fp16/32
20
+ img /= 255 # 0 - 255 to 0.0 - 1.0
21
+ return img
22
+
23
+
24
+ class DFL(nn.Module):
25
+ # Integral module of Distribution Focal Loss (DFL) proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
26
+ def __init__(self, c1=16):
27
+ super().__init__()
28
+ self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
29
+ x = torch.arange(c1, dtype=torch.float)
30
+ self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
31
+ self.c1 = c1
32
+
33
+ def forward(self, x):
34
+ b, c, a = x.shape # batch, channels, anchors
35
+ return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(
36
+ b, 4, a
37
+ )
38
+
39
+
40
+ def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
41
+ """Transform distance(ltrb) to box(xywh or xyxy)."""
42
+ lt, rb = torch.split(distance, 2, dim)
43
+ x1y1 = anchor_points - lt
44
+ x2y2 = anchor_points + rb
45
+ if xywh:
46
+ c_xy = (x1y1 + x2y2) / 2
47
+ wh = x2y2 - x1y1
48
+ return torch.cat((c_xy, wh), dim) # xywh bbox
49
+ return torch.cat((x1y1, x2y2), dim) # xyxy bbox
50
+
51
+
52
+ def post_process(x):
53
+ dfl = DFL(16)
54
+ anchors = torch.tensor(
55
+ np.load(
56
+ "./anchors.npy",
57
+ allow_pickle=True,
58
+ )
59
+ )
60
+ strides = torch.tensor(
61
+ np.load(
62
+ "./strides.npy",
63
+ allow_pickle=True,
64
+ )
65
+ )
66
+ box, cls = torch.cat([xi.view(x[0].shape[0], 144, -1) for xi in x], 2).split(
67
+ (16 * 4, 80), 1
68
+ )
69
+ dbox = dist2bbox(dfl(box), anchors.unsqueeze(0), xywh=True, dim=1) * strides
70
+ y = torch.cat((dbox, cls.sigmoid()), 1)
71
+ return y, x
72
+
73
+
74
+ def make_parser():
75
+ parser = argparse.ArgumentParser("onnxruntime inference sample")
76
+ parser.add_argument(
77
+ "-m",
78
+ "--model",
79
+ type=str,
80
+ default="./yolov8m_qat.onnx",
81
+ help="input your onnx model.",
82
+ )
83
+ parser.add_argument(
84
+ "-i",
85
+ "--image_path",
86
+ type=str,
87
+ default='./demo.jpg',
88
+ help="path to your input image.",
89
+ )
90
+ parser.add_argument(
91
+ "-o",
92
+ "--output_path",
93
+ type=str,
94
+ default='./demo_infer.jpg',
95
+ help="path to your output directory.",
96
+ )
97
+ parser.add_argument(
98
+ "--ipu", action='store_true', help='flag for ryzen ai'
99
+ )
100
+ parser.add_argument(
101
+ "--provider_config", default='', type=str, help='provider config for ryzen ai'
102
+ )
103
+ return parser
104
+
105
+ classnames = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
106
+ 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
107
+ 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
108
+ 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
109
+ 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
110
+ 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
111
+ 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
112
+ 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
113
+ 'hair drier', 'toothbrush']
114
+ names = {k: classnames[k] for k in range(80)}
115
+ imgsz = [640, 640]
116
+
117
+
118
+ if __name__ == '__main__':
119
+ args = make_parser().parse_args()
120
+ source = args.image_path
121
+ dataset = LoadImages(
122
+ source, imgsz=imgsz, stride=32, auto=False, transforms=None, vid_stride=1
123
+ )
124
+ onnx_weight = args.model
125
+ if args.ipu:
126
+ providers = ["VitisAIExecutionProvider"]
127
+ provider_options = [{"config_file": args.provider_config}]
128
+ onnx_model = onnxruntime.InferenceSession(onnx_weight, providers=providers, provider_options=provider_options)
129
+ else:
130
+ onnx_model = onnxruntime.InferenceSession(onnx_weight)
131
+ for batch in dataset:
132
+ path, im, im0s, vid_cap, s = batch
133
+ im = preprocess(im)
134
+ if len(im.shape) == 3:
135
+ im = im[None]
136
+ outputs = onnx_model.run(None, {onnx_model.get_inputs()[0].name: im.cpu().numpy()})
137
+ outputs = [torch.tensor(item) for item in outputs]
138
+ preds = post_process(outputs)
139
+ preds = non_max_suppression(
140
+ preds, 0.25, 0.7, agnostic=False, max_det=300, classes=None
141
+ )
142
+ plot_images(
143
+ im,
144
+ *output_to_target(preds, max_det=15),
145
+ source,
146
+ fname=args.output_path,
147
+ names=names,
148
+ )
requirements.txt ADDED
@@ -0,0 +1,43 @@
1
+ # Ultralytics requirements
2
+ # Usage: pip install -r requirements.txt
3
+
4
+ # Base ----------------------------------------
5
+ matplotlib>=3.2.2
6
+ numpy>=1.18.5
7
+ opencv-python>=4.6.0
8
+ Pillow>=7.1.2
9
+ PyYAML>=5.3.1
10
+ requests>=2.23.0
11
+ scipy>=1.4.1
12
+ torch>=1.7.0
13
+ torchvision>=0.8.1
14
+ tqdm>=4.64.0
15
+
16
+ # Logging -------------------------------------
17
+ tensorboard>=2.4.1
18
+ # clearml
19
+ # comet
20
+
21
+ # Plotting ------------------------------------
22
+ pandas>=1.1.4
23
+ seaborn>=0.11.0
24
+
25
+ # Export --------------------------------------
26
+ # onnxruntime
27
+ # coremltools>=6.0 # CoreML export
28
+ onnx>=1.12.0 # ONNX export
29
+ # onnx-simplifier>=0.4.1 # ONNX simplifier
30
+ # nvidia-pyindex # TensorRT export
31
+ # nvidia-tensorrt # TensorRT export
32
+ # scikit-learn==0.19.2 # CoreML quantization
33
+ # tensorflow>=2.4.1 # TF exports (-cpu, -aarch64, -macos)
34
+ # tensorflowjs>=3.9.0 # TF.js export
35
+ # openvino-dev>=2022.3 # OpenVINO export
36
+
37
+ # Extras --------------------------------------
38
+ ipython # interactive notebook
39
+ psutil # system utilization
40
+ thop>=0.1.1 # FLOPs computation
41
+ # albumentations>=1.0.3
42
+ pycocotools>=2.0.6 # COCO mAP
43
+ # roboflow
strides.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21fa445f4e0732b0213026bb865551d7986579ea19b21d81763c218be0368b9e
3
+ size 33728
utils.py ADDED
@@ -0,0 +1,2140 @@
1
+ import threading
2
+ import os
3
+ import contextlib
4
+ import torch
5
+ import torch.nn as nn
6
+ from PIL import Image, ImageDraw, ImageFont, ExifTags
7
+ from PIL import __version__ as pil_version
8
+ from multiprocessing.pool import ThreadPool
9
+ import numpy as np
10
+ from itertools import repeat
11
+ import glob
12
+ import cv2
13
+ import tempfile
14
+ import hashlib
15
+ from pathlib import Path
16
+ import time
17
+ import torchvision
18
+ import math
19
+ import re
20
+ from typing import List, Union, Dict
21
+ import pkg_resources as pkg
22
+ from types import SimpleNamespace
23
+ from torch.utils.data import Dataset, DataLoader
24
+ from tqdm import tqdm
25
+ import random
26
+ import yaml
27
+ import logging.config
28
+ import sys
29
+ import pathlib
30
+ CURRENT_DIR = pathlib.Path(__file__).parent
31
+ sys.path.append(str(CURRENT_DIR))
32
+
33
+ LOGGING_NAME = 'ultralytics'
34
+ LOGGER = logging.getLogger(LOGGING_NAME)
35
+ for fn in LOGGER.info, LOGGER.warning:
36
+ setattr(LOGGER, fn.__name__, lambda x, fn=fn: fn(x))  # bind fn per iteration to avoid late-binding bug
37
+ IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm" # include image suffixes
38
+ VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv" # include video suffixes
39
+ TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' # tqdm bar format
40
+ NUM_THREADS = min(8, os.cpu_count())
41
+ PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
42
+ _formats = ["xyxy", "xywh", "ltwh"]
43
+ CFG_FLOAT_KEYS = {'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear'}
44
+ CFG_FRACTION_KEYS = {
45
+ 'dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr', 'fl_gamma',
46
+ 'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud', 'fliplr', 'mosaic',
47
+ 'mixup', 'copy_paste', 'conf', 'iou'}
48
+ CFG_INT_KEYS = {
49
+ 'epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
50
+ 'line_thickness', 'workspace', 'nbs'}
51
+ CFG_BOOL_KEYS = {
52
+ 'save', 'exist_ok', 'pretrained', 'verbose', 'deterministic', 'single_cls', 'image_weights', 'rect', 'cos_lr',
53
+ 'overlap_mask', 'val', 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf',
54
+ 'save_crop', 'hide_labels', 'hide_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks', 'boxes', 'keras',
55
+ 'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'v5loader'}
56
+ # Get orientation exif tag
57
+ for orientation in ExifTags.TAGS.keys():
58
+ if ExifTags.TAGS[orientation] == 'Orientation':
59
+ break
60
+
61
+ def segments2boxes(segments):
62
+ """
63
+ It converts segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
64
+
65
+ Args:
66
+ segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates
67
+
68
+ Returns:
69
+ (np.ndarray): the xywh coordinates of the bounding boxes.
70
+ """
71
+ boxes = []
72
+ for s in segments:
73
+ x, y = s.T # segment xy
74
+ boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
75
+ return xyxy2xywh(np.array(boxes)) # cls, xywh
76
+
77
+
78
+ def check_version(
79
+ current: str = "0.0.0",
80
+ minimum: str = "0.0.0",
81
+ name: str = "version ",
82
+ pinned: bool = False,
83
+ hard: bool = False,
84
+ verbose: bool = False,
85
+ ) -> bool:
86
+ """
87
+ Check current version against the required minimum version.
88
+
89
+ Args:
90
+ current (str): Current version.
91
+ minimum (str): Required minimum version.
92
+ name (str): Name to be used in warning message.
93
+ pinned (bool): If True, versions must match exactly. If False, minimum version must be satisfied.
94
+ hard (bool): If True, raise an AssertionError if the minimum version is not met.
95
+ verbose (bool): If True, print warning message if minimum version is not met.
96
+
97
+ Returns:
98
+ bool: True if minimum version is met, False otherwise.
99
+ """
100
+ current, minimum = (pkg.parse_version(x) for x in (current, minimum))
101
+ result = (current == minimum) if pinned else (current >= minimum) # bool
102
+ warning_message = f"WARNING ⚠️ {name}{minimum} is required by YOLOv8, but {name}{current} is currently installed"
103
+ if verbose and not result:
104
+ LOGGER.warning(warning_message)
105
+ return result
106
+
107
+
108
+ TORCH_1_9 = check_version(torch.__version__, '1.9.0')
109
+
110
+
111
+ def smart_inference_mode():
112
+ # Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
113
+ def decorate(fn):
114
+ return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)
115
+
116
+ return decorate
117
+
118
+
119
+ def box_iou(box1, box2, eps=1e-7):
120
+ # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
121
+ """
122
+ Return intersection-over-union (Jaccard index) of boxes.
123
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
124
+ Arguments:
125
+ box1 (Tensor[N, 4])
126
+ box2 (Tensor[M, 4])
127
+ Returns:
128
+ iou (Tensor[N, M]): the NxM matrix containing the pairwise
129
+ IoU values for every element in boxes1 and boxes2
130
+ """
131
+
132
+ # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
133
+ (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
134
+ inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
135
+
136
+ # IoU = inter / (area1 + area2 - inter)
137
+ return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
138
+
139
+
140
+ class LoadImages:
141
+ # YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`
142
+ def __init__(
143
+ self, path, imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1
144
+ ):
145
+ # *.txt file with img/vid/dir on each line
146
+ if isinstance(path, str) and Path(path).suffix == ".txt":
147
+ path = Path(path).read_text().rsplit()
148
+ files = []
149
+ for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
150
+ p = str(Path(p).resolve())
151
+ if "*" in p:
152
+ files.extend(sorted(glob.glob(p, recursive=True))) # glob
153
+ elif os.path.isdir(p):
154
+ files.extend(sorted(glob.glob(os.path.join(p, "*.*")))) # dir
155
+ elif os.path.isfile(p):
156
+ files.append(p) # files
157
+ else:
158
+ raise FileNotFoundError(f"{p} does not exist")
159
+ # include image suffixes
160
+ images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS]
161
+ videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS]
162
+ ni, nv = len(images), len(videos)
163
+
164
+ self.imgsz = imgsz
165
+ self.stride = stride
166
+ self.files = images + videos
167
+ self.nf = ni + nv # number of files
168
+ self.video_flag = [False] * ni + [True] * nv
169
+ self.mode = "image"
170
+ self.auto = auto
171
+ self.transforms = transforms # optional
172
+ self.vid_stride = vid_stride # video frame-rate stride
173
+ self.bs = 1
174
+ if any(videos):
175
+ self.orientation = None # rotation degrees
176
+ self._new_video(videos[0]) # new video
177
+ else:
178
+ self.cap = None
179
+ if self.nf == 0:
180
+ raise FileNotFoundError(
181
+ f"No images or videos found in {p}. "
182
+ f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
183
+ )
184
+
185
+ def __iter__(self):
186
+ self.count = 0
187
+ return self
188
+
189
+ def __next__(self):
190
+ if self.count == self.nf:
191
+ raise StopIteration
192
+ path = self.files[self.count]
193
+
194
+ if self.video_flag[self.count]:
195
+ # Read video
196
+ self.mode = "video"
197
+ for _ in range(self.vid_stride):
198
+ self.cap.grab()
199
+ success, im0 = self.cap.retrieve()
200
+ while not success:
201
+ self.count += 1
202
+ self.cap.release()
203
+ if self.count == self.nf: # last video
204
+ raise StopIteration
205
+ path = self.files[self.count]
206
+ self._new_video(path)
207
+ success, im0 = self.cap.read()
208
+
209
+ self.frame += 1
210
+ s = f"video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: "
211
+
212
+ else:
213
+ # Read image
214
+ self.count += 1
215
+ im0 = cv2.imread(path) # BGR
216
+ if im0 is None:
217
+ raise FileNotFoundError(f"Image Not Found {path}")
218
+ s = f"image {self.count}/{self.nf} {path}: "
219
+
220
+ if self.transforms:
221
+ im = self.transforms(im0) # transforms
222
+ else:
223
+ im = LetterBox(self.imgsz, self.auto, stride=self.stride)(image=im0)
224
+ im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
225
+ im = np.ascontiguousarray(im) # contiguous
226
+
227
+ return path, im, im0, self.cap, s
228
+
229
+ def _new_video(self, path):
230
+ # Create a new video capture object
231
+ self.frame = 0
232
+ self.cap = cv2.VideoCapture(path)
233
+ self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
234
+ if hasattr(cv2, "CAP_PROP_ORIENTATION_META"): # cv2<4.6.0 compatibility
235
+ self.orientation = int(
236
+ self.cap.get(cv2.CAP_PROP_ORIENTATION_META)
237
+ ) # rotation degrees
238
+ # Disable auto-orientation due to known issues in https://github.com/ultralytics/yolov5/issues/8493
239
+ # self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0)
240
+
241
+ def _cv2_rotate(self, im):
242
+ # Rotate a cv2 video manually
243
+ if self.orientation == 0:
244
+ return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
245
+ elif self.orientation == 180:
246
+ return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
247
+ elif self.orientation == 90:
248
+ return cv2.rotate(im, cv2.ROTATE_180)
249
+ return im
250
+
251
+ def __len__(self):
252
+ return self.nf # number of files
253
+
254
+
255
+ class LetterBox:
256
+ """Resize image and padding for detection, instance segmentation, pose"""
257
+
258
+ def __init__(
259
+ self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32
260
+ ):
261
+ self.new_shape = new_shape
262
+ self.auto = auto
263
+ self.scaleFill = scaleFill
264
+ self.scaleup = scaleup
265
+ self.stride = stride
266
+
267
+ def __call__(self, labels=None, image=None):
268
+ if labels is None:
269
+ labels = {}
270
+ img = labels.get("img") if image is None else image
271
+ shape = img.shape[:2] # current shape [height, width]
272
+ new_shape = labels.pop("rect_shape", self.new_shape)
273
+ if isinstance(new_shape, int):
274
+ new_shape = (new_shape, new_shape)
275
+
276
+ # Scale ratio (new / old)
277
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
278
+ # only scale down, do not scale up (for better val mAP)
279
+ if not self.scaleup:
280
+ r = min(r, 1.0)
281
+
282
+ # Compute padding
283
+ ratio = r, r # width, height ratios
284
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
285
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
286
+ if self.auto: # minimum rectangle
287
+ dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride) # wh padding
288
+ elif self.scaleFill: # stretch
289
+ dw, dh = 0.0, 0.0
290
+ new_unpad = (new_shape[1], new_shape[0])
291
+ ratio = (
292
+ new_shape[1] / shape[1],
293
+ new_shape[0] / shape[0],
294
+ ) # width, height ratios
295
+
296
+ dw /= 2 # divide padding into 2 sides
297
+ dh /= 2
298
+ if labels.get("ratio_pad"):
299
+ labels["ratio_pad"] = (labels["ratio_pad"], (dw, dh)) # for evaluation
300
+
301
+ if shape[::-1] != new_unpad: # resize
302
+ img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
303
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
304
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
305
+ img = cv2.copyMakeBorder(
306
+ img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
307
+ ) # add border
308
+
309
+ if len(labels):
310
+ labels = self._update_labels(labels, ratio, dw, dh)
311
+ labels["img"] = img
312
+ labels["resized_shape"] = new_shape
313
+ return labels
314
+ else:
315
+ return img
316
+
317
+ def _update_labels(self, labels, ratio, padw, padh):
318
+ """Update labels"""
319
+ labels["instances"].convert_bbox(format="xyxy")
320
+ labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
321
+ labels["instances"].scale(*ratio)
322
+ labels["instances"].add_padding(padw, padh)
323
+ return labels
324
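+ # Illustrative usage sketch of LetterBox (file name is only an example):
+ #   im0 = cv2.imread("demo.jpg")  # any HxWx3 BGR image
+ #   im = LetterBox(new_shape=(640, 640), auto=False, stride=32)(image=im0)
+ #   # im is 640x640, padded with the gray value (114, 114, 114)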
+
325
+
326
+ class Annotator:
327
+ # YOLOv8 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
328
+ def __init__(
329
+ self,
330
+ im,
331
+ line_width=None,
332
+ font_size=None,
333
+ font="Arial.ttf",
334
+ pil=False,
335
+ example="abc",
336
+ ):
337
+ assert (
338
+ im.data.contiguous
339
+ ), "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images."
340
+ # non-Latin labels, e.g. Asian, Arabic, Cyrillic
341
+ non_ascii = not is_ascii(example)
342
+ self.pil = pil or non_ascii
343
+ if self.pil: # use PIL
344
+ self.pil_9_2_0_check = check_version(
345
+ pil_version, "9.2.0"
346
+ ) # deprecation check
347
+ self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
348
+ self.draw = ImageDraw.Draw(self.im)
349
+ self.font = ImageFont.load_default()
350
+ else: # use cv2
351
+ self.im = im
352
+ self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width
353
+
354
+ def box_label(
355
+ self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255)
356
+ ):
357
+ # Add one xyxy box to image with label
358
+ if isinstance(box, torch.Tensor):
359
+ box = box.tolist()
360
+ if self.pil or not is_ascii(label):
361
+ self.draw.rectangle(box, width=self.lw, outline=color) # box
362
+ if label:
363
+ if self.pil_9_2_0_check:
364
+ _, _, w, h = self.font.getbbox(label) # text width, height (New)
365
+ else:
366
+ w, h = self.font.getsize(
367
+ label
368
+ ) # text width, height (Old, deprecated in 9.2.0)
369
+ outside = box[1] - h >= 0 # label fits outside box
370
+ self.draw.rectangle(
371
+ (
372
+ box[0],
373
+ box[1] - h if outside else box[1],
374
+ box[0] + w + 1,
375
+ box[1] + 1 if outside else box[1] + h + 1,
376
+ ),
377
+ fill=color,
378
+ )
379
+ # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0
380
+ self.draw.text(
381
+ (box[0], box[1] - h if outside else box[1]),
382
+ label,
383
+ fill=txt_color,
384
+ font=self.font,
385
+ )
386
+ else: # cv2
387
+ p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
388
+ cv2.rectangle(
389
+ self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA
390
+ )
391
+ if label:
392
+ tf = max(self.lw - 1, 1) # font thickness
393
+ # text width, height
394
+ w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0]
395
+ outside = p1[1] - h >= 3
396
+ p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
397
+ cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
398
+ cv2.putText(
399
+ self.im,
400
+ label,
401
+ (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
402
+ 0,
403
+ self.lw / 3,
404
+ txt_color,
405
+ thickness=tf,
406
+ lineType=cv2.LINE_AA,
407
+ )
408
+
409
+ def rectangle(self, xy, fill=None, outline=None, width=1):
410
+ # Add rectangle to image (PIL-only)
411
+ self.draw.rectangle(xy, fill, outline, width)
412
+
413
+ def text(self, xy, text, txt_color=(255, 255, 255), anchor="top"):
414
+ # Add text to image (PIL-only)
415
+ if anchor == "bottom": # start y from font bottom
416
+ w, h = self.font.getsize(text) # text width, height
417
+ xy[1] += 1 - h
418
+ self.draw.text(xy, text, fill=txt_color, font=self.font)
419
+
420
+ def fromarray(self, im):
421
+ # Update self.im from a numpy array
422
+ self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
423
+ self.draw = ImageDraw.Draw(self.im)
424
+
425
+ def result(self):
426
+ # Return annotated image as array
427
+ return np.asarray(self.im)
428
+
429
+
430
+ def non_max_suppression(
431
+ prediction,
432
+ conf_thres=0.25,
433
+ iou_thres=0.45,
434
+ classes=None,
435
+ agnostic=False,
436
+ multi_label=False,
437
+ labels=(),
438
+ max_det=300,
439
+ nm=0, # number of masks
440
+ ):
441
+ # Checks
442
+ assert (
443
+ 0 <= conf_thres <= 1
444
+ ), f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
445
+ assert (
446
+ 0 <= iou_thres <= 1
447
+ ), f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
448
+ # YOLOv8 model in validation mode, output = (inference_out, loss_out)
449
+ if isinstance(prediction, (list, tuple)):
450
+ prediction = prediction[0] # select only inference output
451
+ device = prediction.device
452
+ mps = "mps" in device.type # Apple MPS
453
+ if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
454
+ prediction = prediction.cpu()
455
+ bs = prediction.shape[0] # batch size
456
+ nc = prediction.shape[1] - nm - 4 # number of classes
457
+ mi = 4 + nc # mask start index
458
+ xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates
459
+
460
+ # Settings
461
+ # min_wh = 2 # (pixels) minimum box width and height
462
+ max_wh = 7680 # (pixels) maximum box width and height
463
+ max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
464
+ time_limit = 0.5 + 0.05 * bs # seconds to quit after
465
+ redundant = True # require redundant detections
466
+ multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
467
+ merge = False # use merge-NMS
468
+
469
+ t = time.time()
470
+ output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
471
+ for xi, x in enumerate(prediction): # image index, image inference
472
+ # Apply constraints
473
+ # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height
474
+ x = x.transpose(0, -1)[xc[xi]] # confidence
475
+
476
+ # Cat apriori labels if autolabelling
477
+ if labels and len(labels[xi]):
478
+ lb = labels[xi]
479
+ v = torch.zeros((len(lb), nc + nm + 4), device=x.device) # boxes carry 4 + nc + nm columns (no objectness in v8)
480
+ v[:, :4] = lb[:, 1:5] # box
481
+ v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls
482
+ x = torch.cat((x, v), 0)
483
+
484
+ # If none remain process next image
485
+ if not x.shape[0]:
486
+ continue
487
+
488
+ # Detections matrix nx6 (xyxy, conf, cls)
489
+ box, cls, mask = x.split((4, nc, nm), 1)
490
+ # (center_x, center_y, width, height) to (x1, y1, x2, y2)
491
+ box = xywh2xyxy(box)
492
+ if multi_label:
493
+ i, j = (cls > conf_thres).nonzero(as_tuple=False).T
494
+ x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
495
+ else: # best class only
496
+ conf, j = cls.max(1, keepdim=True)
497
+ x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
498
+
499
+ # Filter by class
500
+ if classes is not None:
501
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
502
+
503
+ # Check shape
504
+ n = x.shape[0] # number of boxes
505
+ if not n: # no boxes
506
+ continue
507
+ # sort by confidence and remove excess boxes
508
+ x = x[x[:, 4].argsort(descending=True)[:max_nms]]
509
+
510
+ # Batched NMS
511
+ c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
512
+ # boxes (offset by class), scores
513
+ boxes, scores = x[:, :4] + c, x[:, 4]
514
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
515
+ i = i[:max_det] # limit detections
516
+ if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean)
517
+ # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
518
+ iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
519
+ weights = iou * scores[None] # box weights
520
+ x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(
521
+ 1, keepdim=True
522
+ ) # merged boxes
523
+ if redundant:
524
+ i = i[iou.sum(1) > 1] # require redundancy
525
+
526
+ output[xi] = x[i]
527
+ if mps:
528
+ output[xi] = output[xi].to(device)
529
+ if (time.time() - t) > time_limit:
530
+ LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded")
531
+ break # time limit exceeded
532
+
533
+ return output
534
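+ # Illustrative usage of non_max_suppression on decoded outputs (post_process is defined further below;
+ # onnx_outputs stands in for the three raw head tensors and is hypothetical here):
+ #   preds, _ = post_process(onnx_outputs)
+ #   dets = non_max_suppression(preds, conf_thres=0.25, iou_thres=0.45)
+ #   # dets[i] is an (n, 6) tensor for image i: x1, y1, x2, y2, confidence, class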
+
535
+
536
+ class Colors:
537
+ # Ultralytics color palette https://ultralytics.com/
538
+ def __init__(self):
539
+ # hex = matplotlib.colors.TABLEAU_COLORS.values()
540
+ hexs = (
541
+ "FF3838",
542
+ "FF9D97",
543
+ "FF701F",
544
+ "FFB21D",
545
+ "CFD231",
546
+ "48F90A",
547
+ "92CC17",
548
+ "3DDB86",
549
+ "1A9334",
550
+ "00D4BB",
551
+ "2C99A8",
552
+ "00C2FF",
553
+ "344593",
554
+ "6473FF",
555
+ "0018EC",
556
+ "8438FF",
557
+ "520085",
558
+ "CB38FF",
559
+ "FF95C8",
560
+ "FF37C7",
561
+ )
562
+ self.palette = [self.hex2rgb(f"#{c}") for c in hexs]
563
+ self.n = len(self.palette)
564
+
565
+ def __call__(self, i, bgr=False):
566
+ c = self.palette[int(i) % self.n]
567
+ return (c[2], c[1], c[0]) if bgr else c
568
+
569
+ @staticmethod
570
+ def hex2rgb(h): # rgb order (PIL)
571
+ return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
572
+
573
+
574
+ colors = Colors() # create instance for 'from utils.plots import colors'
575
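+ # Illustrative usage: pick a stable per-class colour from the palette.
+ #   colors(3)            # RGB tuple for class index 3
+ #   colors(3, bgr=True)  # same colour in BGR order for cv2 drawing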
+
576
+
577
+ def threaded(func):
578
+ # Multi-threads a target function and returns thread. Usage: @threaded decorator
579
+ def wrapper(*args, **kwargs):
580
+ thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
581
+ thread.start()
582
+ return thread
583
+
584
+ return wrapper
585
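+ # Illustrative usage of @threaded (the decorated function and its arguments are hypothetical):
+ #   @threaded
+ #   def save_async(path, img):
+ #       cv2.imwrite(path, img)          # slow I/O runs in the background
+ #   t = save_async("out.jpg", im0)      # returns the started daemon Thread immediately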
+
586
+
587
+ def plot_images(
588
+ images,
589
+ batch_idx,
590
+ cls,
591
+ bboxes,
592
+ masks=np.zeros(0, dtype=np.uint8),
593
+ paths=None,
594
+ fname="images.jpg",
595
+ names=None,
596
+ ):
597
+ # Plot image grid with labels
598
+ if isinstance(images, torch.Tensor):
599
+ images = images.cpu().float().numpy()
600
+ if isinstance(cls, torch.Tensor):
601
+ cls = cls.cpu().numpy()
602
+ if isinstance(bboxes, torch.Tensor):
603
+ bboxes = bboxes.cpu().numpy()
604
+ if isinstance(masks, torch.Tensor):
605
+ masks = masks.cpu().numpy().astype(int)
606
+ if isinstance(batch_idx, torch.Tensor):
607
+ batch_idx = batch_idx.cpu().numpy()
608
+
609
+ max_size = 1920 # max image size
610
+ max_subplots = 16 # max image subplots, i.e. 4x4
611
+ bs, _, h, w = images.shape # batch size, _, height, width
612
+ bs = min(bs, max_subplots) # limit plot images
613
+ ns = np.ceil(bs**0.5) # number of subplots (square)
614
+ if np.max(images[0]) <= 1:
615
+ images *= 255 # de-normalise (optional)
616
+
617
+ # Build Image
618
+ mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
619
+ for i, im in enumerate(images):
620
+ if i == max_subplots: # if last batch has fewer images than we expect
621
+ break
622
+ x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
623
+ im = im.transpose(1, 2, 0)
624
+ mosaic[y : y + h, x : x + w, :] = im
625
+
626
+ # Resize (optional)
627
+ scale = max_size / ns / max(h, w)
628
+ if scale < 1:
629
+ h = math.ceil(scale * h)
630
+ w = math.ceil(scale * w)
631
+ mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
632
+
633
+ # Annotate
634
+ fs = int((h + w) * ns * 0.01) # font size
635
+ annotator = Annotator(
636
+ mosaic, line_width=2, font_size=fs, pil=True, example=names
637
+ )
638
+ for i in range(i + 1):
639
+ x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
640
+ annotator.rectangle(
641
+ [x, y, x + w, y + h], None, (255, 255, 255), width=2
642
+ ) # borders
643
+ if paths:
644
+ annotator.text(
645
+ # filenames
646
+ (x + 5, y + 5 + h),
647
+ text=Path(paths[i]).name[:40],
648
+ txt_color=(220, 220, 220),
649
+ )
650
+ if len(cls) > 0:
651
+ idx = batch_idx == i
652
+
653
+ boxes = xywh2xyxy(bboxes[idx, :4]).T
654
+ classes = cls[idx].astype("int")
655
+ labels = bboxes.shape[1] == 4 # labels if no conf column
656
+ # check for confidence presence (label vs pred)
657
+ conf = None if labels else bboxes[idx, 4]
658
+
659
+ if boxes.shape[1]:
660
+ if boxes.max() <= 1.01: # if normalized with tolerance 0.01
661
+ boxes[[0, 2]] *= w # scale to pixels
662
+ boxes[[1, 3]] *= h
663
+ elif scale < 1: # absolute coords need scale if image scales
664
+ boxes *= scale
665
+ boxes[[0, 2]] += x
666
+ boxes[[1, 3]] += y
667
+ for j, box in enumerate(boxes.T.tolist()):
668
+ c = classes[j]
669
+ color = colors(c)
670
+ c = names[c] if names else c
671
+ if labels or conf[j] > 0.25: # 0.25 conf thresh
672
+ label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
673
+ annotator.box_label(box, label, color=color)
674
+ annotator.im.save(fname) # save
675
+
676
+
677
+ def output_to_target(output, max_det=300):
678
+ # Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting
679
+ targets = []
680
+ for i, o in enumerate(output):
681
+ box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
682
+ j = torch.full((conf.shape[0], 1), i)
683
+ targets.append(torch.cat((j, cls, xyxy2xywh(box), conf), 1))
684
+ targets = torch.cat(targets, 0).numpy()
685
+ return targets[:, 0], targets[:, 1], targets[:, 2:]
686
+
687
+
688
+ def is_ascii(s=""):
689
+ # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
690
+ s = str(s) # convert list, tuple, None, etc. to str
691
+ return len(s.encode().decode("ascii", "ignore")) == len(s)
692
+
693
+
694
+ def xyxy2xywh(x):
695
+ """
696
+ Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format.
697
+
698
+ Args:
699
+ x (np.ndarray) or (torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
700
+ Returns:
701
+ y (np.ndarray) or (torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
702
+ """
703
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
704
+ y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
705
+ y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
706
+ y[..., 2] = x[..., 2] - x[..., 0] # width
707
+ y[..., 3] = x[..., 3] - x[..., 1] # height
708
+ return y
709
+
710
+
711
+ def xywh2xyxy(x):
712
+ # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
713
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
714
+ y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
715
+ y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
716
+ y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
717
+ y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
718
+ return y
719
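+ # Worked example: a 10x10 box centred at (50, 50).
+ #   xywh2xyxy(np.array([[50., 50., 10., 10.]]))  # -> [[45., 45., 55., 55.]]
+ #   xyxy2xywh(np.array([[45., 45., 55., 55.]]))  # -> [[50., 50., 10., 10.]]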
+
720
+
721
+ def check_det_dataset(dataset, autodownload=True):
722
+ # Download, check and/or unzip dataset if not found locally
723
+ data = dataset
724
+ # Download (optional)
725
+ extract_dir = ''
726
+
727
+ # Read yaml (optional)
728
+ if isinstance(data, (str, Path)):
729
+ data = yaml_load(data, append_filename=True) # dictionary
730
+
731
+ # Checks
732
+ if isinstance(data['names'], (list, tuple)): # old array format
733
+ data['names'] = dict(enumerate(data['names'])) # convert to dict
734
+ data['nc'] = len(data['names'])
735
+
736
+ # Resolve paths
737
+ path = Path(extract_dir or data.get('path') or Path(data.get('yaml_file', '')).parent) # dataset root
738
+
739
+ DATASETS_DIR = os.path.abspath('.')
740
+ if not path.is_absolute():
741
+ path = (DATASETS_DIR / path).resolve()
742
+ data['path'] = path # download scripts
743
+ for k in 'train', 'val', 'test':
744
+ if data.get(k): # prepend path
745
+ if isinstance(data[k], str):
746
+ x = (path / data[k]).resolve()
747
+ if not x.exists() and data[k].startswith('../'):
748
+ x = (path / data[k][3:]).resolve()
749
+ data[k] = str(x)
750
+ else:
751
+ data[k] = [str((path / x).resolve()) for x in data[k]]
752
+
753
+ # Parse yaml
754
+ train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
755
+ if val:
756
+ val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
757
+ if not all(x.exists() for x in val):
758
+ msg = f"\nDataset '{dataset}' not found ⚠️, missing paths %s" % [str(x) for x in val if not x.exists()]
759
+ if s and autodownload:
760
+ LOGGER.warning(msg)
761
+ else:
762
+ raise FileNotFoundError(msg)
763
+ t = time.time()
764
+ if s.startswith('bash '): # bash script
765
+ LOGGER.info(f'Running {s} ...')
766
+ r = os.system(s)
767
+ else: # python script
768
+ r = exec(s, {'yaml': data}) # return None
769
+ dt = f'({round(time.time() - t, 1)}s)'
770
+ s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
771
+ LOGGER.info(f"Dataset download {s}\n")
772
+
773
+ return data # dictionary
774
+
775
+
776
+ def yaml_load(file='data.yaml', append_filename=False):
777
+ """
778
+ Load YAML data from a file.
779
+
780
+ Args:
781
+ file (str, optional): File name. Default is 'data.yaml'.
782
+ append_filename (bool): Add the YAML filename to the YAML dictionary. Default is False.
783
+
784
+ Returns:
785
+ dict: YAML data and file name.
786
+ """
787
+ with open(file, errors='ignore', encoding='utf-8') as f:
788
+ # Add YAML filename to dict and return
789
+ s = f.read() # string
790
+ if not s.isprintable(): # remove special characters
791
+ s = re.sub(r'[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+', '', s)
792
+ return {**yaml.safe_load(s), 'yaml_file': str(file)} if append_filename else yaml.safe_load(s)
793
+
794
+
795
+ class IterableSimpleNamespace(SimpleNamespace):
796
+ """
797
+ Iterable SimpleNamespace class to allow SimpleNamespace to be used with dict() and in for loops
798
+ """
799
+
800
+ def __iter__(self):
801
+ return iter(vars(self).items())
802
+
803
+ def __str__(self):
804
+ return '\n'.join(f"{k}={v}" for k, v in vars(self).items())
805
+
806
+ def get(self, key, default=None):
807
+ return getattr(self, key, default)
808
+
809
+
810
+ def colorstr(*input):
811
+ # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
812
+ *args, string = input if len(input) > 1 else ("blue", "bold", input[0]) # color arguments, string
813
+ colors = {
814
+ "black": "\033[30m", # basic colors
815
+ "red": "\033[31m",
816
+ "green": "\033[32m",
817
+ "yellow": "\033[33m",
818
+ "blue": "\033[34m",
819
+ "magenta": "\033[35m",
820
+ "cyan": "\033[36m",
821
+ "white": "\033[37m",
822
+ "bright_black": "\033[90m", # bright colors
823
+ "bright_red": "\033[91m",
824
+ "bright_green": "\033[92m",
825
+ "bright_yellow": "\033[93m",
826
+ "bright_blue": "\033[94m",
827
+ "bright_magenta": "\033[95m",
828
+ "bright_cyan": "\033[96m",
829
+ "bright_white": "\033[97m",
830
+ "end": "\033[0m", # misc
831
+ "bold": "\033[1m",
832
+ "underline": "\033[4m"}
833
+ return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
834
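+ # Illustrative usage of colorstr:
+ #   colorstr("hello world")                  # defaults to bold blue
+ #   colorstr("bright_red", "bold", "error")  # any of the keys above may be combined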
+
835
+
836
+ def seed_worker(worker_id):
837
+ # Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader
838
+ worker_seed = torch.initial_seed() % 2 ** 32
839
+ np.random.seed(worker_seed)
840
+ random.seed(worker_seed)
841
+
842
+
843
+ def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, rank=-1, mode="train"):
844
+ assert mode in ["train", "val"]
845
+ shuffle = mode == "train"
846
+ if cfg.rect and shuffle:
847
+ LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
848
+ shuffle = False
849
+ dataset = YOLODataset(
850
+ img_path=img_path,
851
+ imgsz=cfg.imgsz,
852
+ batch_size=batch,
853
+ augment=mode == "train", # augmentation
854
+ hyp=cfg,
855
+ rect=cfg.rect or rect, # rectangular batches
856
+ cache=cfg.cache or None,
857
+ single_cls=cfg.single_cls or False,
858
+ stride=int(stride),
859
+ pad=0.0 if mode == "train" else 0.5,
860
+ prefix=colorstr(f"{mode}: "),
861
+ use_segments=cfg.task == "segment",
862
+ use_keypoints=cfg.task == "keypoint",
863
+ names=names)
864
+
865
+ batch = min(batch, len(dataset))
866
+ nd = torch.cuda.device_count() # number of CUDA devices
867
+ workers = cfg.workers if mode == "train" else cfg.workers * 2
868
+ nw = min([os.cpu_count() // max(nd, 1), batch if batch > 1 else 0, workers]) # number of workers
869
+
870
+ if rank == -1:
871
+ sampler = None
872
+ # weighted-image and close-mosaic samplers are not included in this standalone module,
873
+ loader = DataLoader # so the standard DataLoader is always used
874
+ generator = torch.Generator()
875
+ generator.manual_seed(6148914691236517205)
876
+ return loader(dataset=dataset,
877
+ batch_size=batch,
878
+ shuffle=shuffle and sampler is None,
879
+ num_workers=nw,
880
+ sampler=sampler,
881
+ pin_memory=PIN_MEMORY,
882
+ collate_fn=getattr(dataset, "collate_fn", None),
883
+ worker_init_fn=seed_worker,
884
+ generator=generator), dataset
885
+
886
+
887
+ class BaseDataset(Dataset):
888
+ """Base Dataset.
889
+ Args:
890
+ img_path (str): image path.
891
+ imgsz (int): target image size for resizing/letterboxing. Default 640.
892
+ cache (bool | str): cache images in RAM (True) or on disk ("disk") to speed up loading.
893
+ """
894
+
895
+ def __init__(
896
+ self,
897
+ img_path,
898
+ imgsz=640,
899
+ cache=False,
900
+ augment=True,
901
+ hyp=None,
902
+ prefix="",
903
+ rect=False,
904
+ batch_size=None,
905
+ stride=32,
906
+ pad=0.5,
907
+ single_cls=False,
908
+ ):
909
+ super().__init__()
910
+ self.img_path = img_path
911
+ self.imgsz = imgsz
912
+ self.augment = augment
913
+ self.single_cls = single_cls
914
+ self.prefix = prefix
915
+ self.im_files = self.get_img_files(self.img_path)
916
+ self.labels = self.get_labels()
917
+ self.ni = len(self.labels)
918
+
919
+ # rect stuff
920
+ self.rect = rect
921
+ self.batch_size = batch_size
922
+ self.stride = stride
923
+ self.pad = pad
924
+ if self.rect:
925
+ assert self.batch_size is not None
926
+ self.set_rectangle()
927
+
928
+ # cache stuff
929
+ self.ims = [None] * self.ni
930
+ self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
931
+ if cache:
932
+ self.cache_images(cache)
933
+
934
+ # transforms
935
+ self.transforms = self.build_transforms(hyp=hyp)
936
+
937
+ def get_img_files(self, img_path):
938
+ """Read image files."""
939
+ try:
940
+ f = [] # image files
941
+ for p in img_path if isinstance(img_path, list) else [img_path]:
942
+ p = Path(p) # os-agnostic
943
+ if p.is_dir(): # dir
944
+ f += glob.glob(str(p / "**" / "*.*"), recursive=True)
945
+ # f = list(p.rglob('*.*')) # pathlib
946
+ elif p.is_file(): # file
947
+ with open(p) as t:
948
+ t = t.read().strip().splitlines()
949
+ parent = str(p.parent) + os.sep
950
+ f += [x.replace("./", parent) if x.startswith("./") else x for x in t] # local to global path
951
+ # f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
952
+ else:
953
+ raise FileNotFoundError(f"{self.prefix}{p} does not exist")
954
+ im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
955
+ # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
956
+ assert im_files, f"{self.prefix}No images found"
957
+ except Exception as e:
958
+ raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n") from e
959
+ return im_files
960
+
961
+ def load_image(self, i):
962
+ # Loads 1 image from dataset index 'i', returns (im, resized hw)
963
+ im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
964
+ if im is None: # not cached in RAM
965
+ if fn.exists(): # load npy
966
+ im = np.load(fn)
967
+ else: # read image
968
+ im = cv2.imread(f) # BGR
969
+ if im is None:
970
+ raise FileNotFoundError(f"Image Not Found {f}")
971
+ h0, w0 = im.shape[:2] # orig hw
972
+ r = self.imgsz / max(h0, w0) # ratio
973
+ if r != 1: # if sizes are not equal
974
+ interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
975
+ im = cv2.resize(im, (math.ceil(w0 * r), math.ceil(h0 * r)), interpolation=interp)
976
+ return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
977
+ return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
978
+
979
+ def cache_images(self, cache):
980
+ # cache images to memory or disk
981
+ gb = 0 # Gigabytes of cached images
982
+ self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
983
+ fcn = self.cache_images_to_disk if cache == "disk" else self.load_image
984
+ with ThreadPool(NUM_THREADS) as pool:
985
+ results = pool.imap(fcn, range(self.ni))
986
+ pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT)
987
+ for i, x in pbar:
988
+ if cache == "disk":
989
+ gb += self.npy_files[i].stat().st_size
990
+ else: # 'ram'
991
+ self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
992
+ gb += self.ims[i].nbytes
993
+ pbar.desc = f"{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})"
994
+ pbar.close()
995
+
996
+ def cache_images_to_disk(self, i):
997
+ # Saves an image as an *.npy file for faster loading
998
+ f = self.npy_files[i]
999
+ if not f.exists():
1000
+ np.save(f.as_posix(), cv2.imread(self.im_files[i]))
1001
+
1002
+ def set_rectangle(self):
1003
+ bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
1004
+ nb = bi[-1] + 1 # number of batches
1005
+
1006
+ s = np.array([x.pop("shape") for x in self.labels]) # hw
1007
+ ar = s[:, 0] / s[:, 1] # aspect ratio
1008
+ irect = ar.argsort()
1009
+ self.im_files = [self.im_files[i] for i in irect]
1010
+ self.labels = [self.labels[i] for i in irect]
1011
+ ar = ar[irect]
1012
+
1013
+ # Set training image shapes
1014
+ shapes = [[1, 1]] * nb
1015
+ for i in range(nb):
1016
+ ari = ar[bi == i]
1017
+ mini, maxi = ari.min(), ari.max()
1018
+ if maxi < 1:
1019
+ shapes[i] = [maxi, 1]
1020
+ elif mini > 1:
1021
+ shapes[i] = [1, 1 / mini]
1022
+
1023
+ self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
1024
+ self.batch = bi # batch index of image
1025
+
1026
+ def __getitem__(self, index):
1027
+ return self.transforms(self.get_label_info(index))
1028
+
1029
+ def get_label_info(self, index):
1030
+ label = self.labels[index].copy()
1031
+ label.pop("shape", None) # shape is for rect, remove it
1032
+ label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
1033
+ label["ratio_pad"] = (
1034
+ label["resized_shape"][0] / label["ori_shape"][0],
1035
+ label["resized_shape"][1] / label["ori_shape"][1],
1036
+ ) # for evaluation
1037
+ if self.rect:
1038
+ label["rect_shape"] = self.batch_shapes[self.batch[index]]
1039
+ label = self.update_labels_info(label)
1040
+ return label
1041
+
1042
+ def __len__(self):
1043
+ return len(self.labels)
1044
+
1045
+ def update_labels_info(self, label):
1046
+ """custom your label format here"""
1047
+ return label
1048
+
1049
+ def build_transforms(self, hyp=None):
1050
+ """Users can custom augmentations here
1051
+ like:
1052
+ if self.augment:
1053
+ # training transforms
1054
+ return Compose([])
1055
+ else:
1056
+ # val transforms
1057
+ return Compose([])
1058
+ """
1059
+ raise NotImplementedError
1060
+
1061
+ def get_labels(self):
1062
+ """Users can custom their own format here.
1063
+ Make sure your output is a list with each element like below:
1064
+ dict(
1065
+ im_file=im_file,
1066
+ shape=shape, # format: (height, width)
1067
+ cls=cls,
1068
+ bboxes=bboxes, # xywh
1069
+ segments=segments, # xy
1070
+ keypoints=keypoints, # xy
1071
+ normalized=True, # or False
1072
+ bbox_format="xyxy", # or xywh, ltwh
1073
+ )
1074
+ """
1075
+ raise NotImplementedError
1076
+
1077
+
1078
+ def img2label_paths(img_paths):
1079
+ # Define label paths as a function of image paths
1080
+ sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}" # /images/, /labels/ substrings
1081
+ return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
1082
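+ # Illustrative example (path is hypothetical, assuming a POSIX path separator):
+ #   img2label_paths(["datasets/coco/images/val2017/000000000139.jpg"])
+ #   # -> ["datasets/coco/labels/val2017/000000000139.txt"]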
+
1083
+
1084
+ def get_hash(paths):
1085
+ # Returns a single hash value of a list of paths (files or dirs)
1086
+ size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
1087
+ h = hashlib.md5(str(size).encode()) # hash sizes
1088
+ h.update("".join(paths).encode()) # hash paths
1089
+ return h.hexdigest() # return hash
1090
+
1091
+
1092
+ class Compose:
1093
+
1094
+ def __init__(self, transforms):
1095
+ self.transforms = transforms
1096
+
1097
+ def __call__(self, data):
1098
+ for t in self.transforms:
1099
+ data = t(data)
1100
+ return data
1101
+
1102
+ def append(self, transform):
1103
+ self.transforms.append(transform)
1104
+
1105
+ def tolist(self):
1106
+ return self.transforms
1107
+
1108
+ def __repr__(self):
1109
+ format_string = f"{self.__class__.__name__}("
1110
+ for t in self.transforms:
1111
+ format_string += "\n"
1112
+ format_string += f" {t}"
1113
+ format_string += "\n)"
1114
+ return format_string
1115
+
1116
+
1117
+ class Format:
1118
+
1119
+ def __init__(self,
1120
+ bbox_format="xywh",
1121
+ normalize=True,
1122
+ return_mask=False,
1123
+ return_keypoint=False,
1124
+ mask_ratio=4,
1125
+ mask_overlap=True,
1126
+ batch_idx=True):
1127
+ self.bbox_format = bbox_format
1128
+ self.normalize = normalize
1129
+ self.return_mask = return_mask # set False when training detection only
1130
+ self.return_keypoint = return_keypoint
1131
+ self.mask_ratio = mask_ratio
1132
+ self.mask_overlap = mask_overlap
1133
+ self.batch_idx = batch_idx # keep the batch indexes
1134
+
1135
+ def __call__(self, labels):
1136
+ img = labels.pop("img")
1137
+ h, w = img.shape[:2]
1138
+ cls = labels.pop("cls")
1139
+ instances = labels.pop("instances")
1140
+ instances.convert_bbox(format=self.bbox_format)
1141
+ instances.denormalize(w, h)
1142
+ nl = len(instances)
1143
+
1144
+ if self.normalize:
1145
+ instances.normalize(w, h)
1146
+ labels["img"] = self._format_img(img)
1147
+ labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
1148
+ labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
1149
+ if self.return_keypoint:
1150
+ labels["keypoints"] = torch.from_numpy(instances.keypoints) if nl else torch.zeros((nl, 17, 2))
1151
+ # then we can use collate_fn
1152
+ if self.batch_idx:
1153
+ labels["batch_idx"] = torch.zeros(nl)
1154
+ return labels
1155
+
1156
+ def _format_img(self, img):
1157
+ if len(img.shape) < 3:
1158
+ img = np.expand_dims(img, -1)
1159
+ img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1])
1160
+ img = torch.from_numpy(img)
1161
+ return img
1162
+
1163
+ class Bboxes:
1164
+ """Now only numpy is supported"""
1165
+
1166
+ def __init__(self, bboxes, format="xyxy") -> None:
1167
+ assert format in _formats
1168
+ bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
1169
+ assert bboxes.ndim == 2
1170
+ assert bboxes.shape[1] == 4
1171
+ self.bboxes = bboxes
1172
+ self.format = format
1173
+
1174
+ def convert(self, format):
1175
+ assert format in _formats
1176
+ if self.format == format:
1177
+ return
1178
+ elif self.format == "xyxy":
1179
+ if format == "xywh":
1180
+ bboxes = xyxy2xywh(self.bboxes)
1181
+ elif self.format == "xywh":
1182
+ if format == "xyxy":
1183
+ bboxes = xywh2xyxy(self.bboxes)
1184
+ self.bboxes = bboxes
1185
+ self.format = format
1186
+
1187
+ def areas(self):
1188
+ self.convert("xyxy")
1189
+ return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])
1190
+
1191
+ def mul(self, scale):
1192
+ """
1193
+ Args:
1194
+ scale (tuple | List | int): the scale for four coords.
1195
+ """
1196
+ assert isinstance(scale, (tuple, list))
1197
+ assert len(scale) == 4
1198
+ self.bboxes[:, 0] *= scale[0]
1199
+ self.bboxes[:, 1] *= scale[1]
1200
+ self.bboxes[:, 2] *= scale[2]
1201
+ self.bboxes[:, 3] *= scale[3]
1202
+
1203
+ def add(self, offset):
1204
+ """
1205
+ Args:
1206
+ offset (tuple | List | int): the offset for four coords.
1207
+ """
1208
+ assert isinstance(offset, (tuple, list))
1209
+ assert len(offset) == 4
1210
+ self.bboxes[:, 0] += offset[0]
1211
+ self.bboxes[:, 1] += offset[1]
1212
+ self.bboxes[:, 2] += offset[2]
1213
+ self.bboxes[:, 3] += offset[3]
1214
+
1215
+ def __len__(self):
1216
+ return len(self.bboxes)
1217
+
1218
+ @classmethod
1219
+ def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes":
1220
+ """
1221
+ Concatenates a list of Bboxes objects into a single Bboxes object
1222
+
1223
+ Arguments:
1224
+ boxes_list (list[Bboxes])
1225
+
1226
+ Returns:
1227
+ Bboxes: the concatenated Boxes
1228
+ """
1229
+ assert isinstance(boxes_list, (list, tuple))
1230
+ if not boxes_list:
1231
+ return cls(np.empty(0))
1232
+ assert all(isinstance(box, Bboxes) for box in boxes_list)
1233
+
1234
+ if len(boxes_list) == 1:
1235
+ return boxes_list[0]
1236
+ return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))
1237
+
1238
+ def __getitem__(self, index) -> "Bboxes":
1239
+ """
1240
+ Args:
1241
+ index: int, slice, or a BoolArray
1242
+
1243
+ Returns:
1244
+ Bboxes: Create a new :class:`Bboxes` by indexing.
1245
+ """
1246
+ if isinstance(index, int):
1247
+ return Bboxes(self.bboxes[index].view(1, -1))
1248
+ b = self.bboxes[index]
1249
+ assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
1250
+ return Bboxes(b)
1251
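+ # Illustrative usage of Bboxes:
+ #   bb = Bboxes(np.array([[45., 45., 55., 55.]]), format="xyxy")
+ #   bb.convert("xywh")   # in-place format conversion
+ #   bb.areas()           # -> array([100.]) (converts back to xyxy internally)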
+
1252
+
1253
+ def resample_segments(segments, n=1000):
1254
+ """
1255
+ Takes a list of (m, 2) segment arrays and returns them up-sampled to n points each.
1256
+
1257
+ Args:
1258
+ segments (list): a list of (n,2) arrays, where n is the number of points in the segment.
1259
+ n (int): number of points to resample the segment to. Defaults to 1000
1260
+
1261
+ Returns:
1262
+ segments (list): the resampled segments.
1263
+ """
1264
+ for i, s in enumerate(segments):
1265
+ s = np.concatenate((s, s[0:1, :]), axis=0)
1266
+ x = np.linspace(0, len(s) - 1, n)
1267
+ xp = np.arange(len(s))
1268
+ segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
1269
+ return segments
1270
+
1271
+
1272
+ class Instances:
1273
+
1274
+ def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
1275
+ """
1276
+ Args:
1277
+ bboxes (ndarray): bboxes with shape [N, 4].
1278
+ segments (list | ndarray): segments.
1279
+ keypoints (ndarray): keypoints with shape [N, 17, 2].
1280
+ """
1281
+ if segments is None:
1282
+ segments = []
1283
+ self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
1284
+ self.keypoints = keypoints
1285
+ self.normalized = normalized
1286
+
1287
+ if len(segments) > 0:
1288
+ # list[np.array(1000, 2)] * num_samples
1289
+ segments = resample_segments(segments)
1290
+ # (N, 1000, 2)
1291
+ segments = np.stack(segments, axis=0)
1292
+ else:
1293
+ segments = np.zeros((0, 1000, 2), dtype=np.float32)
1294
+ self.segments = segments
1295
+
1296
+ def convert_bbox(self, format):
1297
+ self._bboxes.convert(format=format)
1298
+
1299
+ def bbox_areas(self):
1300
+ self._bboxes.areas()
1301
+
1302
+ def scale(self, scale_w, scale_h, bbox_only=False):
1303
+ """this might be similar with denormalize func but without normalized sign"""
1304
+ self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
1305
+ if bbox_only:
1306
+ return
1307
+ self.segments[..., 0] *= scale_w
1308
+ self.segments[..., 1] *= scale_h
1309
+ if self.keypoints is not None:
1310
+ self.keypoints[..., 0] *= scale_w
1311
+ self.keypoints[..., 1] *= scale_h
1312
+
1313
+ def denormalize(self, w, h):
1314
+ if not self.normalized:
1315
+ return
1316
+ self._bboxes.mul(scale=(w, h, w, h))
1317
+ self.segments[..., 0] *= w
1318
+ self.segments[..., 1] *= h
1319
+ if self.keypoints is not None:
1320
+ self.keypoints[..., 0] *= w
1321
+ self.keypoints[..., 1] *= h
1322
+ self.normalized = False
1323
+
1324
+ def normalize(self, w, h):
1325
+ if self.normalized:
1326
+ return
1327
+ self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h))
1328
+ self.segments[..., 0] /= w
1329
+ self.segments[..., 1] /= h
1330
+ if self.keypoints is not None:
1331
+ self.keypoints[..., 0] /= w
1332
+ self.keypoints[..., 1] /= h
1333
+ self.normalized = True
1334
+
1335
+ def add_padding(self, padw, padh):
1336
+ # handle rect and mosaic situation
1337
+ assert not self.normalized, "you should add padding with absolute coordinates."
1338
+ self._bboxes.add(offset=(padw, padh, padw, padh))
1339
+ self.segments[..., 0] += padw
1340
+ self.segments[..., 1] += padh
1341
+ if self.keypoints is not None:
1342
+ self.keypoints[..., 0] += padw
1343
+ self.keypoints[..., 1] += padh
1344
+
1345
+ def __getitem__(self, index) -> "Instances":
1346
+ """
1347
+ Args:
1348
+ index: int, slice, or a BoolArray
1349
+
1350
+ Returns:
1351
+ Instances: Create a new :class:`Instances` by indexing.
1352
+ """
1353
+ segments = self.segments[index] if len(self.segments) else self.segments
1354
+ keypoints = self.keypoints[index] if self.keypoints is not None else None
1355
+ bboxes = self.bboxes[index]
1356
+ bbox_format = self._bboxes.format
1357
+ return Instances(
1358
+ bboxes=bboxes,
1359
+ segments=segments,
1360
+ keypoints=keypoints,
1361
+ bbox_format=bbox_format,
1362
+ normalized=self.normalized,
1363
+ )
1364
+
1365
+ def flipud(self, h):
1366
+ if self._bboxes.format == "xyxy":
1367
+ y1 = self.bboxes[:, 1].copy()
1368
+ y2 = self.bboxes[:, 3].copy()
1369
+ self.bboxes[:, 1] = h - y2
1370
+ self.bboxes[:, 3] = h - y1
1371
+ else:
1372
+ self.bboxes[:, 1] = h - self.bboxes[:, 1]
1373
+ self.segments[..., 1] = h - self.segments[..., 1]
1374
+ if self.keypoints is not None:
1375
+ self.keypoints[..., 1] = h - self.keypoints[..., 1]
1376
+
1377
+ def fliplr(self, w):
1378
+ if self._bboxes.format == "xyxy":
1379
+ x1 = self.bboxes[:, 0].copy()
1380
+ x2 = self.bboxes[:, 2].copy()
1381
+ self.bboxes[:, 0] = w - x2
1382
+ self.bboxes[:, 2] = w - x1
1383
+ else:
1384
+ self.bboxes[:, 0] = w - self.bboxes[:, 0]
1385
+ self.segments[..., 0] = w - self.segments[..., 0]
1386
+ if self.keypoints is not None:
1387
+ self.keypoints[..., 0] = w - self.keypoints[..., 0]
1388
+
1389
+ def clip(self, w, h):
1390
+ ori_format = self._bboxes.format
1391
+ self.convert_bbox(format="xyxy")
1392
+ self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
1393
+ self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
1394
+ if ori_format != "xyxy":
1395
+ self.convert_bbox(format=ori_format)
1396
+ self.segments[..., 0] = self.segments[..., 0].clip(0, w)
1397
+ self.segments[..., 1] = self.segments[..., 1].clip(0, h)
1398
+ if self.keypoints is not None:
1399
+ self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)
1400
+ self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)
1401
+
1402
+ def update(self, bboxes, segments=None, keypoints=None):
1403
+ new_bboxes = Bboxes(bboxes, format=self._bboxes.format)
1404
+ self._bboxes = new_bboxes
1405
+ if segments is not None:
1406
+ self.segments = segments
1407
+ if keypoints is not None:
1408
+ self.keypoints = keypoints
1409
+
1410
+ def __len__(self):
1411
+ return len(self.bboxes)
1412
+
1413
+ @classmethod
1414
+ def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances":
1415
+ """
1416
+ Concatenates a list of Instances objects into a single Instances object
1417
+
1418
+ Arguments:
1419
+ instances_list (list[Instances])
1420
+ axis (int): the axis along which the arrays are concatenated
1421
+
1422
+ Returns:
1423
+ Instances: the concatenated Instances
1424
+ """
1425
+ assert isinstance(instances_list, (list, tuple))
1426
+ if not instances_list:
1427
+ return cls(np.empty(0))
1428
+ assert all(isinstance(instance, Instances) for instance in instances_list)
1429
+
1430
+ if len(instances_list) == 1:
1431
+ return instances_list[0]
1432
+
1433
+ use_keypoint = instances_list[0].keypoints is not None
1434
+ bbox_format = instances_list[0]._bboxes.format
1435
+ normalized = instances_list[0].normalized
1436
+
1437
+ cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
1438
+ cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis)
1439
+ cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None
1440
+ return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized)
1441
+
1442
+ @property
1443
+ def bboxes(self):
1444
+ return self._bboxes.bboxes
1445
+
1446
+
1447
+ def is_dir_writeable(dir_path: Union[str, Path]) -> bool:
1448
+ """
1449
+ Check if a directory is writeable.
1450
+
1451
+ Args:
1452
+ dir_path (str) or (Path): The path to the directory.
1453
+
1454
+ Returns:
1455
+ bool: True if the directory is writeable, False otherwise.
1456
+ """
1457
+ try:
1458
+ with tempfile.TemporaryFile(dir=dir_path):
1459
+ pass
1460
+ return True
1461
+ except OSError:
1462
+ return False
1463
+
1464
+
1465
+ class YOLODataset(BaseDataset):
1466
+ """YOLO Dataset.
1467
+ Args:
1468
+ img_path (str): image path.
1469
+ prefix (str): prefix.
1470
+ """
1471
+ cache_version = '1.0.1' # dataset labels *.cache version, >= 1.0.0 for YOLOv8
1472
+ rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
1473
+
1474
+ def __init__(self,
1475
+ img_path,
1476
+ imgsz=640,
1477
+ cache=False,
1478
+ augment=True,
1479
+ hyp=None,
1480
+ prefix="",
1481
+ rect=False,
1482
+ batch_size=None,
1483
+ stride=32,
1484
+ pad=0.0,
1485
+ single_cls=False,
1486
+ use_segments=False,
1487
+ use_keypoints=False,
1488
+ names=None):
1489
+ self.use_segments = use_segments
1490
+ self.use_keypoints = use_keypoints
1491
+ self.names = names
1492
+ assert not (self.use_segments and self.use_keypoints), "Cannot use both segments and keypoints."
1493
+ super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls)
1494
+
1495
+ def cache_labels(self, path=Path("./labels.cache")):
1496
+ # Cache dataset labels, check images and read shapes
1497
+ if path.exists():
1498
+ path.unlink() # remove *.cache file if exists
1499
+ x = {"labels": []}
1500
+ nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
1501
+ desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
1502
+ total = len(self.im_files)
1503
+ with ThreadPool(NUM_THREADS) as pool:
1504
+ results = pool.imap(func=verify_image_label,
1505
+ iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
1506
+ repeat(self.use_keypoints), repeat(len(self.names))))
1507
+ pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
1508
+ for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
1509
+ nm += nm_f
1510
+ nf += nf_f
1511
+ ne += ne_f
1512
+ nc += nc_f
1513
+ if im_file:
1514
+ x["labels"].append(
1515
+ dict(
1516
+ im_file=im_file,
1517
+ shape=shape,
1518
+ cls=lb[:, 0:1], # n, 1
1519
+ bboxes=lb[:, 1:], # n, 4
1520
+ segments=segments,
1521
+ keypoints=keypoint,
1522
+ normalized=True,
1523
+ bbox_format="xywh"))
1524
+ if msg:
1525
+ msgs.append(msg)
1526
+ pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
1527
+ pbar.close()
1528
+
1529
+ if msgs:
1530
+ LOGGER.info("\n".join(msgs))
1531
+ x["hash"] = get_hash(self.label_files + self.im_files)
1532
+ x["results"] = nf, nm, ne, nc, len(self.im_files)
1533
+ x["msgs"] = msgs # warnings
1534
+ x["version"] = self.cache_version # cache version
1535
+ self.im_files = [lb["im_file"] for lb in x["labels"]] # update im_files
1536
+ if is_dir_writeable(path.parent):
1537
+ np.save(str(path), x) # save cache for next time
1538
+ path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
1539
+ LOGGER.info(f"{self.prefix}New cache created: {path}")
1540
+ else:
1541
+ LOGGER.warning(f"{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable") # not writeable
1542
+ return x
1543
+
1544
+ def get_labels(self):
1545
+ self.label_files = img2label_paths(self.im_files)
1546
+ cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
1547
+ try:
1548
+ cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True # load dict
1549
+ assert cache["version"] == self.cache_version # matches current version
1550
+ assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
1551
+ except (FileNotFoundError, AssertionError, AttributeError):
1552
+ cache, exists = self.cache_labels(cache_path), False # run cache ops
1553
+
1554
+ # Display cache
1555
+ nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
1556
+ if exists:
1557
+ d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
1558
+ tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results
1559
+ if cache["msgs"]:
1560
+ LOGGER.info("\n".join(cache["msgs"])) # display warnings
1561
+
1562
+ # Read cache
1563
+ [cache.pop(k) for k in ("hash", "version", "msgs")] # remove items
1564
+ labels = cache["labels"]
1565
+
1566
+ # Check if the dataset is all boxes or all segments
1567
+ len_cls = sum(len(lb["cls"]) for lb in labels)
1568
+ len_boxes = sum(len(lb["bboxes"]) for lb in labels)
1569
+ len_segments = sum(len(lb["segments"]) for lb in labels)
1570
+ if len_segments and len_boxes != len_segments:
1571
+ LOGGER.warning(
1572
+ f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
1573
+ f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
1574
+ "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.")
1575
+ for lb in labels:
1576
+ lb["segments"] = []
1577
+ return labels
1578
+
1579
+
1580
+ def build_transforms(self, hyp=None):
1581
+ transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
1582
+ transforms.append(
1583
+ Format(bbox_format="xywh",
1584
+ normalize=True,
1585
+ return_mask=self.use_segments,
1586
+ return_keypoint=self.use_keypoints,
1587
+ batch_idx=True,
1588
+ mask_ratio=hyp.mask_ratio,
1589
+ mask_overlap=hyp.overlap_mask))
1590
+ return transforms
1591
+
1592
+ def close_mosaic(self, hyp):
1593
+ hyp.mosaic = 0.0 # set mosaic ratio=0.0
1594
+ hyp.copy_paste = 0.0 # keep the same behavior as previous v8 close-mosaic
1595
+ hyp.mixup = 0.0 # keep the same behavior as previous v8 close-mosaic
1596
+ self.transforms = self.build_transforms(hyp)
1597
+
1598
+ def update_labels_info(self, label):
1599
+ """custom your label format here"""
1600
+ # NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
1601
+ # we can make it also support classification and semantic segmentation by add or remove some dict keys there.
1602
+ bboxes = label.pop("bboxes")
1603
+ segments = label.pop("segments")
1604
+ keypoints = label.pop("keypoints", None)
1605
+ bbox_format = label.pop("bbox_format")
1606
+ normalized = label.pop("normalized")
1607
+ label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
1608
+ return label
1609
+
1610
+ @staticmethod
1611
+ def collate_fn(batch):
1612
+ new_batch = {}
1613
+ keys = batch[0].keys()
1614
+ values = list(zip(*[list(b.values()) for b in batch]))
1615
+ for i, k in enumerate(keys):
1616
+ value = values[i]
1617
+ if k == "img":
1618
+ value = torch.stack(value, 0)
1619
+ if k in ["masks", "keypoints", "bboxes", "cls"]:
1620
+ value = torch.cat(value, 0)
1621
+ new_batch[k] = value
1622
+ new_batch["batch_idx"] = list(new_batch["batch_idx"])
1623
+ for i in range(len(new_batch["batch_idx"])):
1624
+ new_batch["batch_idx"][i] += i # add target image index for build_targets()
1625
+ new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
1626
+ return new_batch
1627
+
1628
+
1629
+ class DFL(nn.Module):
1630
+ # Integral module of Distribution Focal Loss (DFL) proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
1631
+ def __init__(self, c1=16):
1632
+ super().__init__()
1633
+ self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
1634
+ x = torch.arange(c1, dtype=torch.float)
1635
+ self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
1636
+ self.c1 = c1
1637
+
1638
+ def forward(self, x):
1639
+ b, c, a = x.shape # batch, channels, anchors
1640
+ return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(
1641
+ b, 4, a
1642
+ )
1643
+
1644
+
1645
+ def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
1646
+ """Transform distance(ltrb) to box(xywh or xyxy)."""
1647
+ lt, rb = torch.split(distance, 2, dim)
1648
+ x1y1 = anchor_points - lt
1649
+ x2y2 = anchor_points + rb
1650
+ if xywh:
1651
+ c_xy = (x1y1 + x2y2) / 2
1652
+ wh = x2y2 - x1y1
1653
+ return torch.cat((c_xy, wh), dim) # xywh bbox
1654
+ return torch.cat((x1y1, x2y2), dim) # xyxy bbox
1655
+
1656
+
1657
+ def post_process(x):
1658
+ dfl = DFL(16)
1659
+ anchors = torch.tensor(
1660
+ np.load(
1661
+ "./anchors.npy",
1662
+ allow_pickle=True,
1663
+ )
1664
+ )
1665
+ strides = torch.tensor(
1666
+ np.load(
1667
+ "./strides.npy",
1668
+ allow_pickle=True,
1669
+ )
1670
+ )
1671
+ box, cls = torch.cat([xi.view(x[0].shape[0], 144, -1) for xi in x], 2).split(
1672
+ (16 * 4, 80), 1
1673
+ )
1674
+ dbox = dist2bbox(dfl(box), anchors.unsqueeze(0), xywh=True, dim=1) * strides
1675
+ y = torch.cat((dbox, cls.sigmoid()), 1)
1676
+ return y, x
1677
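+ # Illustrative decode sketch (the ONNX Runtime session and input feed are assumptions, not part of this file):
+ #   outputs = session.run(None, {session.get_inputs()[0].name: im})  # three raw heads from the ONNX model
+ #   preds, _ = post_process([torch.tensor(o) for o in outputs])
+ #   dets = non_max_suppression(preds, conf_thres=0.25, iou_thres=0.45)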
+
1678
+
1679
+ def smooth(y, f=0.05):
1680
+ # Box filter of fraction f
1681
+ nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd)
1682
+ p = np.ones(nf // 2) # ones padding
1683
+ yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded
1684
+ return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed
1685
+
1686
+
1687
+ def compute_ap(recall, precision):
1688
+ """ Compute the average precision, given the recall and precision curves
1689
+ # Arguments
1690
+ recall: The recall curve (list)
1691
+ precision: The precision curve (list)
1692
+ # Returns
1693
+ Average precision, precision curve, recall curve
1694
+ """
1695
+
1696
+ # Append sentinel values to beginning and end
1697
+ mrec = np.concatenate(([0.0], recall, [1.0]))
1698
+ mpre = np.concatenate(([1.0], precision, [0.0]))
1699
+
1700
+ # Compute the precision envelope
1701
+ mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
1702
+
1703
+ # Integrate area under curve
1704
+ method = 'interp' # methods: 'continuous', 'interp'
1705
+ if method == 'interp':
1706
+ x = np.linspace(0, 1, 101) # 101-point interp (COCO)
1707
+ ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate
1708
+ else: # 'continuous'
1709
+ i = np.where(mrec[1:] != mrec[:-1])[0] # points where x-axis (recall) changes
1710
+ ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve
1711
+
1712
+ return ap, mpre, mrec
1713
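+ # Worked example: a detector whose precision drops from 1.0 to 0.5 as recall reaches 1.0.
+ #   ap, mpre, mrec = compute_ap([0.0, 0.5, 1.0], [1.0, 0.8, 0.5])
+ #   # ap is the 101-point interpolated area under the precision envelope (COCO-style)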
+
1714
+
1715
+ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), names=(), eps=1e-16, prefix=""):
1716
+ """ Compute the average precision, given the recall and precision curves.
1717
+ Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
1718
+ # Arguments
1719
+ tp: True positives (nparray, nx1 or nx10).
1720
+ conf: Objectness value from 0-1 (nparray).
1721
+ pred_cls: Predicted object classes (nparray).
1722
+ target_cls: True object classes (nparray).
1723
+ plot: Plot precision-recall curve at mAP@0.5
1724
+ save_dir: Plot save directory
1725
+ # Returns
1726
+ The average precision as computed in py-faster-rcnn.
1727
+ """
1728
+
1729
+ # Sort by objectness
1730
+ i = np.argsort(-conf)
1731
+ tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
1732
+
1733
+ # Find unique classes
1734
+ unique_classes, nt = np.unique(target_cls, return_counts=True)
1735
+ nc = unique_classes.shape[0] # number of classes, number of detections
1736
+
1737
+ # Create Precision-Recall curve and compute AP for each class
1738
+ px, py = np.linspace(0, 1, 1000), [] # for plotting
1739
+ ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
1740
+ for ci, c in enumerate(unique_classes):
1741
+ i = pred_cls == c
1742
+ n_l = nt[ci] # number of labels
1743
+ n_p = i.sum() # number of predictions
1744
+ if n_p == 0 or n_l == 0:
1745
+ continue
1746
+
1747
+ # Accumulate FPs and TPs
1748
+ fpc = (1 - tp[i]).cumsum(0)
1749
+ tpc = tp[i].cumsum(0)
1750
+
1751
+ # Recall
1752
+ recall = tpc / (n_l + eps) # recall curve
1753
+ r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases
1754
+
1755
+ # Precision
1756
+ precision = tpc / (tpc + fpc) # precision curve
1757
+ p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score
1758
+
1759
+ # AP from recall-precision curve
1760
+ for j in range(tp.shape[1]):
1761
+ ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
1762
+ if plot and j == 0:
1763
+ py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5
1764
+
1765
+ # Compute F1 (harmonic mean of precision and recall)
1766
+ f1 = 2 * p * r / (p + r + eps)
1767
+ names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
1768
+ names = dict(enumerate(names)) # to dict
1769
+
1770
+ i = smooth(f1.mean(0), 0.1).argmax() # max F1 index
1771
+ p, r, f1 = p[:, i], r[:, i], f1[:, i]
1772
+ tp = (r * nt).round() # true positives
1773
+ fp = (tp / (p + eps) - tp).round() # false positives
1774
+ return tp, fp, p, r, f1, ap, unique_classes.astype(int)
1775
+
1776
+
1777
+ class Metric:
1778
+
1779
+ def __init__(self) -> None:
1780
+ self.p = [] # (nc, )
1781
+ self.r = [] # (nc, )
1782
+ self.f1 = [] # (nc, )
1783
+ self.all_ap = [] # (nc, 10)
1784
+ self.ap_class_index = [] # (nc, )
1785
+ self.nc = 0
1786
+
1787
+ @property
1788
+ def ap50(self):
1789
+ """AP@0.5 of all classes.
1790
+ Return:
1791
+ (nc, ) or [].
1792
+ """
1793
+ return self.all_ap[:, 0] if len(self.all_ap) else []
1794
+
1795
+ @property
1796
+ def ap(self):
1797
+ """AP@0.5:0.95
1798
+ Return:
1799
+ (nc, ) or [].
1800
+ """
1801
+ return self.all_ap.mean(1) if len(self.all_ap) else []
1802
+
1803
+ @property
1804
+ def mp(self):
1805
+ """mean precision of all classes.
1806
+ Return:
1807
+ float.
1808
+ """
1809
+ return self.p.mean() if len(self.p) else 0.0
1810
+
1811
+ @property
1812
+ def mr(self):
1813
+ """mean recall of all classes.
1814
+ Return:
1815
+ float.
1816
+ """
1817
+ return self.r.mean() if len(self.r) else 0.0
1818
+
1819
+ @property
1820
+ def map50(self):
1821
+ """Mean AP@0.5 of all classes.
1822
+ Return:
1823
+ float.
1824
+ """
1825
+ return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0
1826
+
1827
+ @property
1828
+ def map75(self):
1829
+ """Mean AP@0.75 of all classes.
1830
+ Return:
1831
+ float.
1832
+ """
1833
+ return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0
1834
+
1835
+ @property
1836
+ def map(self):
1837
+ """Mean AP@0.5:0.95 of all classes.
1838
+ Return:
1839
+ float.
1840
+ """
1841
+ return self.all_ap.mean() if len(self.all_ap) else 0.0
1842
+
1843
+ def mean_results(self):
1844
+ """Mean of results, return mp, mr, map50, map"""
1845
+ return [self.mp, self.mr, self.map50, self.map]
1846
+
1847
+ def class_result(self, i):
1848
+ """class-aware result, return p[i], r[i], ap50[i], ap[i]"""
1849
+ return self.p[i], self.r[i], self.ap50[i], self.ap[i]
1850
+
1851
+ @property
1852
+ def maps(self):
1853
+ """mAP of each class"""
1854
+ maps = np.zeros(self.nc) + self.map
1855
+ for i, c in enumerate(self.ap_class_index):
1856
+ maps[c] = self.ap[i]
1857
+ return maps
1858
+
1859
+ def fitness(self):
1860
+ # Model fitness as a weighted combination of metrics
1861
+ w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
1862
+ return (np.array(self.mean_results()) * w).sum()
1863
+
1864
+ def update(self, results):
1865
+ """
1866
+ Args:
1867
+ results: tuple(p, r, f1, all_ap, ap_class_index), as returned by ap_per_class()[2:]
1868
+ """
1869
+ self.p, self.r, self.f1, self.all_ap, self.ap_class_index = results
1870
+
1871
+
1872
+ class DetMetrics:
1873
+
1874
+ def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
1875
+ self.save_dir = save_dir
1876
+ self.plot = plot
1877
+ self.names = names
1878
+ self.box = Metric()
1879
+
1880
+ def process(self, tp, conf, pred_cls, target_cls):
1881
+ results = ap_per_class(tp, conf, pred_cls, target_cls, plot=self.plot, save_dir=self.save_dir,
1882
+ names=self.names)[2:]
1883
+ self.box.nc = len(self.names)
1884
+ self.box.update(results)
1885
+
1886
+ @property
1887
+ def keys(self):
1888
+ return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
1889
+
1890
+ def mean_results(self):
1891
+ return self.box.mean_results()
1892
+
1893
+ def class_result(self, i):
1894
+ return self.box.class_result(i)
1895
+
1896
+ @property
1897
+ def maps(self):
1898
+ return self.box.maps
1899
+
1900
+ @property
1901
+ def fitness(self):
1902
+ return self.box.fitness()
1903
+
1904
+ @property
1905
+ def ap_class_index(self):
1906
+ return self.box.ap_class_index
1907
+
1908
+ @property
1909
+ def results_dict(self):
1910
+ return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
1911
+
1912
+
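`DetMetrics` is the wrapper that the evaluation loop feeds with matched predictions. A hedged usage sketch follows; the `tp`, `conf` and class arrays are random stand-ins for what `onnx_eval.py` accumulates over the validation set:

```python
import numpy as np

# Illustrative only: 100 fake predictions over an 80-class dataset, with a
# boolean TP matrix at 10 IoU thresholds per prediction.
n_pred, n_iou = 100, 10
tp = np.random.rand(n_pred, n_iou) > 0.5        # TP/FP flags per IoU threshold
conf = np.random.rand(n_pred)                   # confidence scores
pred_cls = np.random.randint(0, 80, n_pred)     # predicted class ids
target_cls = np.random.randint(0, 80, 200)      # ground-truth class ids

names = {i: f"class_{i}" for i in range(80)}    # placeholder class names
metrics = DetMetrics(names=names)
metrics.process(tp, conf, pred_cls, target_cls)

print(metrics.keys)            # metric names reported by the evaluator
print(metrics.mean_results())  # [precision, recall, mAP50, mAP50-95]
print(metrics.results_dict)    # same values keyed by name, plus fitness
```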
1913
+ def increment_path(path, exist_ok=False, sep='', mkdir=False):
1914
+ """
1915
+ Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
1916
+
1917
+ If the path exists and exist_ok is not set to True, the path will be incremented by appending a number and sep to
1918
+ the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the
1919
+ number will be appended directly to the end of the path. If mkdir is set to True, the path will be created as a
1920
+ directory if it does not already exist.
1921
+
1922
+ Args:
1923
+ path (str or pathlib.Path): Path to increment.
1924
+ exist_ok (bool, optional): If True, the path will not be incremented and will be returned as-is. Defaults to False.
1925
+ sep (str, optional): Separator to use between the path and the incrementation number. Defaults to an empty string.
1926
+ mkdir (bool, optional): If True, the path will be created as a directory if it does not exist. Defaults to False.
1927
+
1928
+ Returns:
1929
+ pathlib.Path: Incremented path.
1930
+ """
1931
+ path = Path(path) # os-agnostic
1932
+ if path.exists() and not exist_ok:
1933
+ path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
1934
+
1935
+ # Increment the path until an unused one is found
1936
+ for n in range(2, 9999):
1937
+ p = f'{path}{sep}{n}{suffix}' # increment path
1938
+ if not os.path.exists(p):
1939
+ break
1940
+ path = Path(p)
1941
+
1942
+ if mkdir:
1943
+ path.mkdir(parents=True, exist_ok=True) # make directory
1944
+
1945
+ return path
1946
+
1947
+
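`increment_path` keeps repeated runs from overwriting each other's output directories. For illustration, assuming no `runs/val/exp*` directories exist yet:

```python
from pathlib import Path

# First call: runs/val/exp does not exist, so it is returned unchanged
# (and created, because mkdir=True).
save_dir = increment_path(Path("runs/val") / "exp", exist_ok=False, mkdir=True)
print(save_dir)   # runs/val/exp

# Second call: the directory now exists, so a numeric suffix is appended.
save_dir = increment_path(Path("runs/val") / "exp", exist_ok=False, mkdir=True)
print(save_dir)   # runs/val/exp2

# For files the extension is preserved: results.txt -> results2.txt, results3.txt, ...
```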
1948
+ def cfg2dict(cfg):
1949
+ """
1950
+ Convert a configuration object to a dictionary.
1951
+
1952
+ This function converts a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
1953
+
1954
+ Inputs:
1955
+ cfg (str) or (Path) or (SimpleNamespace): Configuration object to be converted to a dictionary.
1956
+
1957
+ Returns:
1958
+ cfg (dict): Configuration object in dictionary format.
1959
+ """
1960
+ if isinstance(cfg, (str, Path)):
1961
+ cfg = yaml_load(cfg) # load dict
1962
+ elif isinstance(cfg, SimpleNamespace):
1963
+ cfg = vars(cfg) # convert to dict
1964
+ return cfg
1965
+
1966
+
1967
+ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = None, overrides: Dict = None):
1968
+ """
1969
+ Load and merge configuration data from a file or dictionary.
1970
+
1971
+ Args:
1972
+ cfg (str) or (Path) or (Dict) or (SimpleNamespace): Configuration data.
1973
+ overrides (str) or (Dict), optional: Overrides in the form of a file name or a dictionary. Default is None.
1974
+
1975
+ Returns:
1976
+ (SimpleNamespace): Training arguments namespace.
1977
+ """
1978
+ cfg = cfg2dict(cfg)
1979
+
1980
+ # Merge overrides
1981
+ if overrides:
1982
+ overrides = cfg2dict(overrides)
1983
+ cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)
1984
+
1985
+ # Special handling for numeric project/names
1986
+ for k in 'project', 'name':
1987
+ if k in cfg and isinstance(cfg[k], (int, float)):
1988
+ cfg[k] = str(cfg[k])
1989
+
1990
+ # Type and Value checks
1991
+ for k, v in cfg.items():
1992
+ if v is not None: # None values may be from optional args
1993
+ if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
1994
+ raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
1995
+ f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
1996
+ elif k in CFG_FRACTION_KEYS:
1997
+ if not isinstance(v, (int, float)):
1998
+ raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
1999
+ f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
2000
+ if not (0.0 <= v <= 1.0):
2001
+ raise ValueError(f"'{k}={v}' is an invalid value. "
2002
+ f"Valid '{k}' values are between 0.0 and 1.0.")
2003
+ elif k in CFG_INT_KEYS and not isinstance(v, int):
2004
+ raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
2005
+ f"'{k}' must be an int (i.e. '{k}=0')")
2006
+ elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
2007
+ raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
2008
+ f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')")
2009
+
2010
+ # Return instance
2011
+ return IterableSimpleNamespace(**cfg)
2012
+
2013
+
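`cfg2dict` and `get_cfg` turn the shipped `default.yaml` into a validated, attribute-style namespace, with overrides layered on top. A hedged sketch (the override values are arbitrary examples):

```python
# Load the packaged defaults and override a couple of evaluation settings.
# imgsz, conf and iou are keys that exist in default.yaml; the values are examples.
args = get_cfg(cfg="default.yaml", overrides={"imgsz": 640, "conf": 0.001, "iou": 0.7})

print(args.imgsz, args.conf, args.iou)  # attribute access via IterableSimpleNamespace
print(args.task)                        # remaining keys keep their default values

# Type and value checks are enforced, e.g. a fraction key outside [0, 1]:
# get_cfg(cfg="default.yaml", overrides={"conf": 1.5})  # -> ValueError
```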
2014
+ def clip_boxes(boxes, shape):
2015
+ """
2016
+ Clips a set of bounding boxes so that they lie inside an image of the given
2017
+ shape (height, width).
2018
+
2019
+ Args:
2020
+ boxes (torch.Tensor): the bounding boxes to clip
2021
+ shape (tuple): the shape of the image
2022
+ """
2023
+ if isinstance(boxes, torch.Tensor): # faster individually
2024
+ boxes[..., 0].clamp_(0, shape[1]) # x1
2025
+ boxes[..., 1].clamp_(0, shape[0]) # y1
2026
+ boxes[..., 2].clamp_(0, shape[1]) # x2
2027
+ boxes[..., 3].clamp_(0, shape[0]) # y2
2028
+ else: # np.array (faster grouped)
2029
+ boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2
2030
+ boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2
2031
+
2032
+
2033
+ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
2034
+ """
2035
+ Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in
2036
+ (img1_shape) to the shape of a different image (img0_shape).
2037
+
2038
+ Args:
2039
+ img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
2040
+ boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
2041
+ img0_shape (tuple): the shape of the target image, in the format of (height, width).
2042
+ ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
2043
+ calculated based on the size difference between the two images.
2044
+
2045
+ Returns:
2046
+ boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
2047
+ """
2048
+ if ratio_pad is None: # calculate from img0_shape
2049
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
2050
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
2051
+ else:
2052
+ gain = ratio_pad[0][0]
2053
+ pad = ratio_pad[1]
2054
+
2055
+ boxes[..., [0, 2]] -= pad[0] # x padding
2056
+ boxes[..., [1, 3]] -= pad[1] # y padding
2057
+ boxes[..., :4] /= gain
2058
+ clip_boxes(boxes, img0_shape)
2059
+ return boxes
2060
+
2061
+
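`scale_boxes` (together with `clip_boxes`) maps detections from the 640x640 letterboxed network input back onto the original image. A short worked example with made-up numbers:

```python
import torch

# The network ran at 640x640 and the original image was 480x640 (height, width).
img1_shape = (640, 640)   # letterboxed input shape
img0_shape = (480, 640)   # original image shape
# gain = min(640/480, 640/640) = 1.0, padding = (0, 80) in (x, y)

boxes = torch.tensor([[100.0, 120.0, 300.0, 400.0]])  # xyxy in letterboxed coords
boxes = scale_boxes(img1_shape, boxes, img0_shape)
print(boxes)  # tensor([[100., 40., 300., 320.]]) - y shifted up by the 80 px padding
```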
2062
+ def exif_size(img):
2063
+ # Returns exif-corrected PIL size
2064
+ s = img.size # (width, height)
2065
+ with contextlib.suppress(Exception):
2066
+ rotation = dict(img._getexif().items())[orientation]
2067
+ if rotation in [6, 8]: # rotation 270 or 90
2068
+ s = (s[1], s[0])
2069
+ return s
2070
+
2071
+
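`exif_size` matters because phone photos often store landscape pixels plus an EXIF rotation flag; the label verification below relies on the displayed size, not the stored one. A hypothetical example (the file name and orientation tag are assumptions):

```python
from PIL import Image

im = Image.open("rotated_photo.jpg")  # hypothetical JPEG tagged with EXIF orientation 6
print(im.size)         # e.g. (640, 480) - raw stored dimensions
print(exif_size(im))   # e.g. (480, 640) - swapped because orientation is 6 or 8
```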
2072
+ def verify_image_label(args):
2073
+ # Verify one image-label pair
2074
+ im_file, lb_file, prefix, keypoint, num_cls = args
2075
+ # number (missing, found, empty, corrupt), message, segments, keypoints
2076
+ nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
2077
+ try:
2078
+ # verify images
2079
+ im = Image.open(im_file)
2080
+ im.verify() # PIL verify
2081
+ shape = exif_size(im) # image size
2082
+ shape = (shape[1], shape[0]) # hw
2083
+ assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
2084
+ assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
2085
+ if im.format.lower() in ("jpg", "jpeg"):
2086
+ with open(im_file, "rb") as f:
2087
+ f.seek(-2, 2)
+ if f.read() != b"\xff\xd9": # a valid JPEG ends with the EOI marker
+ msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG detected"
2088
+
2089
+ # verify labels
2090
+ if os.path.isfile(lb_file):
2091
+ nf = 1 # label found
2092
+ with open(lb_file) as f:
2093
+ lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
2094
+ if any(len(x) > 6 for x in lb) and (not keypoint): # is segment
2095
+ classes = np.array([x[0] for x in lb], dtype=np.float32)
2096
+ segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...)
2097
+ lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
2098
+ lb = np.array(lb, dtype=np.float32)
2099
+ nl = len(lb)
2100
+ if nl:
2101
+ if keypoint:
2102
+ assert lb.shape[1] == 56, "labels require 56 columns each"
2103
+ assert (lb[:, 5::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
2104
+ assert (lb[:, 6::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
2105
+ kpts = np.zeros((lb.shape[0], 39))
2106
+ for i in range(len(lb)):
2107
+ kpt = np.delete(lb[i, 5:], np.arange(2, lb.shape[1] - 5, 3)) # remove occlusion param from GT
2108
+ kpts[i] = np.hstack((lb[i, :5], kpt))
2109
+ lb = kpts
2110
+ assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion parameter"
2111
+ else:
2112
+ assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
2113
+ assert (lb[:, 1:] <= 1).all(), \
2114
+ f"non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}"
2115
+ # All labels
2116
+ max_cls = int(lb[:, 0].max()) # highest class id used in the labels
2117
+ assert max_cls < num_cls, \
2118
+ f'Label class {max_cls} exceeds dataset class count {num_cls}. ' \
2119
+ f'Possible class labels are 0-{num_cls - 1}'
2120
+ assert (lb >= 0).all(), f"negative label values {lb[lb < 0]}"
2121
+ _, i = np.unique(lb, axis=0, return_index=True)
2122
+ if len(i) < nl: # duplicate row check
2123
+ lb = lb[i] # remove duplicates
2124
+ if segments:
2125
+ segments = [segments[x] for x in i]
2126
+ msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
2127
+ else:
2128
+ ne = 1 # label empty
2129
+ lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
2130
+ else:
2131
+ nm = 1 # label missing
2132
+ lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
2133
+ if keypoint:
2134
+ keypoints = lb[:, 5:].reshape(-1, 17, 2)
2135
+ lb = lb[:, :5]
2136
+ return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
2137
+ except Exception as e:
2138
+ nc = 1
2139
+ msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
2140
+ return [None, None, None, None, None, nm, nf, ne, nc, msg]
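`verify_image_label` is run over every image/label pair when the dataset cache is built (typically via a multiprocessing pool). A hedged single-pair example using the COCO layout from the data-preparation step; the paths are placeholders for your local dataset root:

```python
# Check one COCO val image against its YOLO-format label file.
args = (
    "datasets/coco/images/val2017/000000000139.jpg",  # image file
    "datasets/coco/labels/val2017/000000000139.txt",  # label file
    "",     # log message prefix
    False,  # keypoint mode off (plain detection)
    80,     # number of classes in coco.yaml
)
im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg = verify_image_label(args)
print(shape)           # (height, width) after EXIF correction
print(lb.shape)        # (num_objects, 5): class, x, y, w, h (normalized)
print(nm, nf, ne, nc)  # missing / found / empty / corrupt counters
```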
yolov8m_qat.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b770e88b358ad24cc60e7b8bbc00b09bb1e0308f65f45cdcea2a1dfc1301077
3
+ size 103874610
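The `yolov8m_qat.onnx` entry above is a Git LFS pointer for the ~104 MB quantized model, not the model bytes themselves. As a rough sketch of loading it with onnxruntime (this is not the full pre/post-processing pipeline from `onnx_inference.py`, the input name/shape handling is an assumption, and for NPU execution the Vitis AI execution provider from the Ryzen AI setup would replace the CPU provider):

```python
import numpy as np
import onnxruntime as ort

# Plain CPU session; swap in the Vitis AI execution provider per the Ryzen AI docs
# to run on the NPU.
session = ort.InferenceSession("yolov8m_qat.onnx", providers=["CPUExecutionProvider"])

# Dummy 640x640 input just to exercise the graph; real inputs come from the
# letterbox preprocessing in onnx_inference.py.
inp = np.zeros((1, 3, 640, 640), dtype=np.float32)
outputs = session.run(None, {session.get_inputs()[0].name: inp})
print([o.shape for o in outputs])
```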