app
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- ultralytics/__init__.py +11 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/datasets/Argoverse.yaml +73 -0
- ultralytics/datasets/GlobalWheat2020.yaml +54 -0
- ultralytics/datasets/ImageNet.yaml +2025 -0
- ultralytics/datasets/Objects365.yaml +443 -0
- ultralytics/datasets/SKU-110K.yaml +58 -0
- ultralytics/datasets/VOC.yaml +100 -0
- ultralytics/datasets/VisDrone.yaml +73 -0
- ultralytics/datasets/coco-pose.yaml +38 -0
- ultralytics/datasets/coco.yaml +115 -0
- ultralytics/datasets/coco128-seg.yaml +101 -0
- ultralytics/datasets/coco128.yaml +101 -0
- ultralytics/datasets/coco8-pose.yaml +25 -0
- ultralytics/datasets/coco8-seg.yaml +101 -0
- ultralytics/datasets/coco8.yaml +101 -0
- ultralytics/datasets/xView.yaml +153 -0
- ultralytics/hub/__init__.py +113 -0
- ultralytics/hub/auth.py +139 -0
- ultralytics/hub/session.py +189 -0
- ultralytics/hub/utils.py +217 -0
- ultralytics/models/README.md +45 -0
- ultralytics/models/rt-detr/rt-detr-l.yaml +50 -0
- ultralytics/models/rt-detr/rt-detr-x.yaml +54 -0
- ultralytics/models/v3/yolov3-spp.yaml +48 -0
- ultralytics/models/v3/yolov3-tiny.yaml +39 -0
- ultralytics/models/v3/yolov3.yaml +48 -0
- ultralytics/models/v5/yolov5-p6.yaml +61 -0
- ultralytics/models/v5/yolov5.yaml +50 -0
- ultralytics/models/v6/yolov6.yaml +53 -0
- ultralytics/models/v8/yolov8-cls.yaml +29 -0
- ultralytics/models/v8/yolov8-p2.yaml +54 -0
- ultralytics/models/v8/yolov8-p6.yaml +56 -0
- ultralytics/models/v8/yolov8-pose-p6.yaml +57 -0
- ultralytics/models/v8/yolov8-pose.yaml +47 -0
- ultralytics/models/v8/yolov8-seg.yaml +46 -0
- ultralytics/models/v8/yolov8.yaml +46 -0
- ultralytics/nn/__init__.py +9 -0
- ultralytics/nn/autobackend.py +455 -0
- ultralytics/nn/autoshape.py +243 -0
- ultralytics/nn/modules/__init__.py +29 -0
- ultralytics/nn/modules/block.py +304 -0
- ultralytics/nn/modules/conv.py +297 -0
- ultralytics/nn/modules/head.py +382 -0
- ultralytics/nn/modules/transformer.py +389 -0
- ultralytics/nn/modules/utils.py +78 -0
- ultralytics/nn/tasks.py +773 -0
- ultralytics/tracker/README.md +86 -0
- ultralytics/tracker/__init__.py +6 -0
ultralytics/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
__version__ = '8.0.114'
|
4 |
+
|
5 |
+
from ultralytics.hub import start
|
6 |
+
from ultralytics.vit.rtdetr import RTDETR
|
7 |
+
from ultralytics.vit.sam import SAM
|
8 |
+
from ultralytics.yolo.engine.model import YOLO
|
9 |
+
from ultralytics.yolo.utils.checks import check_yolo as checks
|
10 |
+
|
11 |
+
__all__ = '__version__', 'YOLO', 'SAM', 'RTDETR', 'checks', 'start' # allow simpler import
|
ultralytics/assets/bus.jpg
ADDED
ultralytics/assets/zidane.jpg
ADDED
ultralytics/datasets/Argoverse.yaml
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# Argoverse-HD dataset (ring-front-center camera) http://www.cs.cmu.edu/~mengtial/proj/streaming/ by Argo AI
|
3 |
+
# Example usage: yolo train data=Argoverse.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── Argoverse ← downloads here (31.3 GB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/Argoverse # dataset root dir
|
12 |
+
train: Argoverse-1.1/images/train/ # train images (relative to 'path') 39384 images
|
13 |
+
val: Argoverse-1.1/images/val/ # val images (relative to 'path') 15062 images
|
14 |
+
test: Argoverse-1.1/images/test/ # test images (optional) https://eval.ai/web/challenges/challenge-page/800/overview
|
15 |
+
|
16 |
+
# Classes
|
17 |
+
names:
|
18 |
+
0: person
|
19 |
+
1: bicycle
|
20 |
+
2: car
|
21 |
+
3: motorcycle
|
22 |
+
4: bus
|
23 |
+
5: truck
|
24 |
+
6: traffic_light
|
25 |
+
7: stop_sign
|
26 |
+
|
27 |
+
|
28 |
+
# Download script/URL (optional) ---------------------------------------------------------------------------------------
|
29 |
+
download: |
|
30 |
+
import json
|
31 |
+
from tqdm import tqdm
|
32 |
+
from ultralytics.yolo.utils.downloads import download
|
33 |
+
from pathlib import Path
|
34 |
+
|
35 |
+
def argoverse2yolo(set):
|
36 |
+
labels = {}
|
37 |
+
a = json.load(open(set, "rb"))
|
38 |
+
for annot in tqdm(a['annotations'], desc=f"Converting {set} to YOLOv5 format..."):
|
39 |
+
img_id = annot['image_id']
|
40 |
+
img_name = a['images'][img_id]['name']
|
41 |
+
img_label_name = f'{img_name[:-3]}txt'
|
42 |
+
|
43 |
+
cls = annot['category_id'] # instance class id
|
44 |
+
x_center, y_center, width, height = annot['bbox']
|
45 |
+
x_center = (x_center + width / 2) / 1920.0 # offset and scale
|
46 |
+
y_center = (y_center + height / 2) / 1200.0 # offset and scale
|
47 |
+
width /= 1920.0 # scale
|
48 |
+
height /= 1200.0 # scale
|
49 |
+
|
50 |
+
img_dir = set.parents[2] / 'Argoverse-1.1' / 'labels' / a['seq_dirs'][a['images'][annot['image_id']]['sid']]
|
51 |
+
if not img_dir.exists():
|
52 |
+
img_dir.mkdir(parents=True, exist_ok=True)
|
53 |
+
|
54 |
+
k = str(img_dir / img_label_name)
|
55 |
+
if k not in labels:
|
56 |
+
labels[k] = []
|
57 |
+
labels[k].append(f"{cls} {x_center} {y_center} {width} {height}\n")
|
58 |
+
|
59 |
+
for k in labels:
|
60 |
+
with open(k, "w") as f:
|
61 |
+
f.writelines(labels[k])
|
62 |
+
|
63 |
+
|
64 |
+
# Download
|
65 |
+
dir = Path(yaml['path']) # dataset root dir
|
66 |
+
urls = ['https://argoverse-hd.s3.us-east-2.amazonaws.com/Argoverse-HD-Full.zip']
|
67 |
+
download(urls, dir=dir)
|
68 |
+
|
69 |
+
# Convert
|
70 |
+
annotations_dir = 'Argoverse-HD/annotations/'
|
71 |
+
(dir / 'Argoverse-1.1' / 'tracking').rename(dir / 'Argoverse-1.1' / 'images') # rename 'tracking' to 'images'
|
72 |
+
for d in "train.json", "val.json":
|
73 |
+
argoverse2yolo(dir / annotations_dir / d) # convert VisDrone annotations to YOLO labels
|
ultralytics/datasets/GlobalWheat2020.yaml
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# Global Wheat 2020 dataset http://www.global-wheat.com/ by University of Saskatchewan
|
3 |
+
# Example usage: yolo train data=GlobalWheat2020.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── GlobalWheat2020 ← downloads here (7.0 GB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/GlobalWheat2020 # dataset root dir
|
12 |
+
train: # train images (relative to 'path') 3422 images
|
13 |
+
- images/arvalis_1
|
14 |
+
- images/arvalis_2
|
15 |
+
- images/arvalis_3
|
16 |
+
- images/ethz_1
|
17 |
+
- images/rres_1
|
18 |
+
- images/inrae_1
|
19 |
+
- images/usask_1
|
20 |
+
val: # val images (relative to 'path') 748 images (WARNING: train set contains ethz_1)
|
21 |
+
- images/ethz_1
|
22 |
+
test: # test images (optional) 1276 images
|
23 |
+
- images/utokyo_1
|
24 |
+
- images/utokyo_2
|
25 |
+
- images/nau_1
|
26 |
+
- images/uq_1
|
27 |
+
|
28 |
+
# Classes
|
29 |
+
names:
|
30 |
+
0: wheat_head
|
31 |
+
|
32 |
+
|
33 |
+
# Download script/URL (optional) ---------------------------------------------------------------------------------------
|
34 |
+
download: |
|
35 |
+
from ultralytics.yolo.utils.downloads import download
|
36 |
+
from pathlib import Path
|
37 |
+
|
38 |
+
# Download
|
39 |
+
dir = Path(yaml['path']) # dataset root dir
|
40 |
+
urls = ['https://zenodo.org/record/4298502/files/global-wheat-codalab-official.zip',
|
41 |
+
'https://github.com/ultralytics/yolov5/releases/download/v1.0/GlobalWheat2020_labels.zip']
|
42 |
+
download(urls, dir=dir)
|
43 |
+
|
44 |
+
# Make Directories
|
45 |
+
for p in 'annotations', 'images', 'labels':
|
46 |
+
(dir / p).mkdir(parents=True, exist_ok=True)
|
47 |
+
|
48 |
+
# Move
|
49 |
+
for p in 'arvalis_1', 'arvalis_2', 'arvalis_3', 'ethz_1', 'rres_1', 'inrae_1', 'usask_1', \
|
50 |
+
'utokyo_1', 'utokyo_2', 'nau_1', 'uq_1':
|
51 |
+
(dir / 'global-wheat-codalab-official' / p).rename(dir / 'images' / p) # move to /images
|
52 |
+
f = (dir / 'global-wheat-codalab-official' / p).with_suffix('.json') # json file
|
53 |
+
if f.exists():
|
54 |
+
f.rename((dir / 'annotations' / p).with_suffix('.json')) # move to /annotations
|
ultralytics/datasets/ImageNet.yaml
ADDED
@@ -0,0 +1,2025 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# ImageNet-1k dataset https://www.image-net.org/index.php by Stanford University
|
3 |
+
# Simplified class names from https://github.com/anishathalye/imagenet-simple-labels
|
4 |
+
# Example usage: yolo train task=classify data=imagenet
|
5 |
+
# parent
|
6 |
+
# ├── ultralytics
|
7 |
+
# └── datasets
|
8 |
+
# └── imagenet ← downloads here (144 GB)
|
9 |
+
|
10 |
+
|
11 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
12 |
+
path: ../datasets/imagenet # dataset root dir
|
13 |
+
train: train # train images (relative to 'path') 1281167 images
|
14 |
+
val: val # val images (relative to 'path') 50000 images
|
15 |
+
test: # test images (optional)
|
16 |
+
|
17 |
+
# Classes
|
18 |
+
names:
|
19 |
+
0: tench
|
20 |
+
1: goldfish
|
21 |
+
2: great white shark
|
22 |
+
3: tiger shark
|
23 |
+
4: hammerhead shark
|
24 |
+
5: electric ray
|
25 |
+
6: stingray
|
26 |
+
7: cock
|
27 |
+
8: hen
|
28 |
+
9: ostrich
|
29 |
+
10: brambling
|
30 |
+
11: goldfinch
|
31 |
+
12: house finch
|
32 |
+
13: junco
|
33 |
+
14: indigo bunting
|
34 |
+
15: American robin
|
35 |
+
16: bulbul
|
36 |
+
17: jay
|
37 |
+
18: magpie
|
38 |
+
19: chickadee
|
39 |
+
20: American dipper
|
40 |
+
21: kite
|
41 |
+
22: bald eagle
|
42 |
+
23: vulture
|
43 |
+
24: great grey owl
|
44 |
+
25: fire salamander
|
45 |
+
26: smooth newt
|
46 |
+
27: newt
|
47 |
+
28: spotted salamander
|
48 |
+
29: axolotl
|
49 |
+
30: American bullfrog
|
50 |
+
31: tree frog
|
51 |
+
32: tailed frog
|
52 |
+
33: loggerhead sea turtle
|
53 |
+
34: leatherback sea turtle
|
54 |
+
35: mud turtle
|
55 |
+
36: terrapin
|
56 |
+
37: box turtle
|
57 |
+
38: banded gecko
|
58 |
+
39: green iguana
|
59 |
+
40: Carolina anole
|
60 |
+
41: desert grassland whiptail lizard
|
61 |
+
42: agama
|
62 |
+
43: frilled-necked lizard
|
63 |
+
44: alligator lizard
|
64 |
+
45: Gila monster
|
65 |
+
46: European green lizard
|
66 |
+
47: chameleon
|
67 |
+
48: Komodo dragon
|
68 |
+
49: Nile crocodile
|
69 |
+
50: American alligator
|
70 |
+
51: triceratops
|
71 |
+
52: worm snake
|
72 |
+
53: ring-necked snake
|
73 |
+
54: eastern hog-nosed snake
|
74 |
+
55: smooth green snake
|
75 |
+
56: kingsnake
|
76 |
+
57: garter snake
|
77 |
+
58: water snake
|
78 |
+
59: vine snake
|
79 |
+
60: night snake
|
80 |
+
61: boa constrictor
|
81 |
+
62: African rock python
|
82 |
+
63: Indian cobra
|
83 |
+
64: green mamba
|
84 |
+
65: sea snake
|
85 |
+
66: Saharan horned viper
|
86 |
+
67: eastern diamondback rattlesnake
|
87 |
+
68: sidewinder
|
88 |
+
69: trilobite
|
89 |
+
70: harvestman
|
90 |
+
71: scorpion
|
91 |
+
72: yellow garden spider
|
92 |
+
73: barn spider
|
93 |
+
74: European garden spider
|
94 |
+
75: southern black widow
|
95 |
+
76: tarantula
|
96 |
+
77: wolf spider
|
97 |
+
78: tick
|
98 |
+
79: centipede
|
99 |
+
80: black grouse
|
100 |
+
81: ptarmigan
|
101 |
+
82: ruffed grouse
|
102 |
+
83: prairie grouse
|
103 |
+
84: peacock
|
104 |
+
85: quail
|
105 |
+
86: partridge
|
106 |
+
87: grey parrot
|
107 |
+
88: macaw
|
108 |
+
89: sulphur-crested cockatoo
|
109 |
+
90: lorikeet
|
110 |
+
91: coucal
|
111 |
+
92: bee eater
|
112 |
+
93: hornbill
|
113 |
+
94: hummingbird
|
114 |
+
95: jacamar
|
115 |
+
96: toucan
|
116 |
+
97: duck
|
117 |
+
98: red-breasted merganser
|
118 |
+
99: goose
|
119 |
+
100: black swan
|
120 |
+
101: tusker
|
121 |
+
102: echidna
|
122 |
+
103: platypus
|
123 |
+
104: wallaby
|
124 |
+
105: koala
|
125 |
+
106: wombat
|
126 |
+
107: jellyfish
|
127 |
+
108: sea anemone
|
128 |
+
109: brain coral
|
129 |
+
110: flatworm
|
130 |
+
111: nematode
|
131 |
+
112: conch
|
132 |
+
113: snail
|
133 |
+
114: slug
|
134 |
+
115: sea slug
|
135 |
+
116: chiton
|
136 |
+
117: chambered nautilus
|
137 |
+
118: Dungeness crab
|
138 |
+
119: rock crab
|
139 |
+
120: fiddler crab
|
140 |
+
121: red king crab
|
141 |
+
122: American lobster
|
142 |
+
123: spiny lobster
|
143 |
+
124: crayfish
|
144 |
+
125: hermit crab
|
145 |
+
126: isopod
|
146 |
+
127: white stork
|
147 |
+
128: black stork
|
148 |
+
129: spoonbill
|
149 |
+
130: flamingo
|
150 |
+
131: little blue heron
|
151 |
+
132: great egret
|
152 |
+
133: bittern
|
153 |
+
134: crane (bird)
|
154 |
+
135: limpkin
|
155 |
+
136: common gallinule
|
156 |
+
137: American coot
|
157 |
+
138: bustard
|
158 |
+
139: ruddy turnstone
|
159 |
+
140: dunlin
|
160 |
+
141: common redshank
|
161 |
+
142: dowitcher
|
162 |
+
143: oystercatcher
|
163 |
+
144: pelican
|
164 |
+
145: king penguin
|
165 |
+
146: albatross
|
166 |
+
147: grey whale
|
167 |
+
148: killer whale
|
168 |
+
149: dugong
|
169 |
+
150: sea lion
|
170 |
+
151: Chihuahua
|
171 |
+
152: Japanese Chin
|
172 |
+
153: Maltese
|
173 |
+
154: Pekingese
|
174 |
+
155: Shih Tzu
|
175 |
+
156: King Charles Spaniel
|
176 |
+
157: Papillon
|
177 |
+
158: toy terrier
|
178 |
+
159: Rhodesian Ridgeback
|
179 |
+
160: Afghan Hound
|
180 |
+
161: Basset Hound
|
181 |
+
162: Beagle
|
182 |
+
163: Bloodhound
|
183 |
+
164: Bluetick Coonhound
|
184 |
+
165: Black and Tan Coonhound
|
185 |
+
166: Treeing Walker Coonhound
|
186 |
+
167: English foxhound
|
187 |
+
168: Redbone Coonhound
|
188 |
+
169: borzoi
|
189 |
+
170: Irish Wolfhound
|
190 |
+
171: Italian Greyhound
|
191 |
+
172: Whippet
|
192 |
+
173: Ibizan Hound
|
193 |
+
174: Norwegian Elkhound
|
194 |
+
175: Otterhound
|
195 |
+
176: Saluki
|
196 |
+
177: Scottish Deerhound
|
197 |
+
178: Weimaraner
|
198 |
+
179: Staffordshire Bull Terrier
|
199 |
+
180: American Staffordshire Terrier
|
200 |
+
181: Bedlington Terrier
|
201 |
+
182: Border Terrier
|
202 |
+
183: Kerry Blue Terrier
|
203 |
+
184: Irish Terrier
|
204 |
+
185: Norfolk Terrier
|
205 |
+
186: Norwich Terrier
|
206 |
+
187: Yorkshire Terrier
|
207 |
+
188: Wire Fox Terrier
|
208 |
+
189: Lakeland Terrier
|
209 |
+
190: Sealyham Terrier
|
210 |
+
191: Airedale Terrier
|
211 |
+
192: Cairn Terrier
|
212 |
+
193: Australian Terrier
|
213 |
+
194: Dandie Dinmont Terrier
|
214 |
+
195: Boston Terrier
|
215 |
+
196: Miniature Schnauzer
|
216 |
+
197: Giant Schnauzer
|
217 |
+
198: Standard Schnauzer
|
218 |
+
199: Scottish Terrier
|
219 |
+
200: Tibetan Terrier
|
220 |
+
201: Australian Silky Terrier
|
221 |
+
202: Soft-coated Wheaten Terrier
|
222 |
+
203: West Highland White Terrier
|
223 |
+
204: Lhasa Apso
|
224 |
+
205: Flat-Coated Retriever
|
225 |
+
206: Curly-coated Retriever
|
226 |
+
207: Golden Retriever
|
227 |
+
208: Labrador Retriever
|
228 |
+
209: Chesapeake Bay Retriever
|
229 |
+
210: German Shorthaired Pointer
|
230 |
+
211: Vizsla
|
231 |
+
212: English Setter
|
232 |
+
213: Irish Setter
|
233 |
+
214: Gordon Setter
|
234 |
+
215: Brittany
|
235 |
+
216: Clumber Spaniel
|
236 |
+
217: English Springer Spaniel
|
237 |
+
218: Welsh Springer Spaniel
|
238 |
+
219: Cocker Spaniels
|
239 |
+
220: Sussex Spaniel
|
240 |
+
221: Irish Water Spaniel
|
241 |
+
222: Kuvasz
|
242 |
+
223: Schipperke
|
243 |
+
224: Groenendael
|
244 |
+
225: Malinois
|
245 |
+
226: Briard
|
246 |
+
227: Australian Kelpie
|
247 |
+
228: Komondor
|
248 |
+
229: Old English Sheepdog
|
249 |
+
230: Shetland Sheepdog
|
250 |
+
231: collie
|
251 |
+
232: Border Collie
|
252 |
+
233: Bouvier des Flandres
|
253 |
+
234: Rottweiler
|
254 |
+
235: German Shepherd Dog
|
255 |
+
236: Dobermann
|
256 |
+
237: Miniature Pinscher
|
257 |
+
238: Greater Swiss Mountain Dog
|
258 |
+
239: Bernese Mountain Dog
|
259 |
+
240: Appenzeller Sennenhund
|
260 |
+
241: Entlebucher Sennenhund
|
261 |
+
242: Boxer
|
262 |
+
243: Bullmastiff
|
263 |
+
244: Tibetan Mastiff
|
264 |
+
245: French Bulldog
|
265 |
+
246: Great Dane
|
266 |
+
247: St. Bernard
|
267 |
+
248: husky
|
268 |
+
249: Alaskan Malamute
|
269 |
+
250: Siberian Husky
|
270 |
+
251: Dalmatian
|
271 |
+
252: Affenpinscher
|
272 |
+
253: Basenji
|
273 |
+
254: pug
|
274 |
+
255: Leonberger
|
275 |
+
256: Newfoundland
|
276 |
+
257: Pyrenean Mountain Dog
|
277 |
+
258: Samoyed
|
278 |
+
259: Pomeranian
|
279 |
+
260: Chow Chow
|
280 |
+
261: Keeshond
|
281 |
+
262: Griffon Bruxellois
|
282 |
+
263: Pembroke Welsh Corgi
|
283 |
+
264: Cardigan Welsh Corgi
|
284 |
+
265: Toy Poodle
|
285 |
+
266: Miniature Poodle
|
286 |
+
267: Standard Poodle
|
287 |
+
268: Mexican hairless dog
|
288 |
+
269: grey wolf
|
289 |
+
270: Alaskan tundra wolf
|
290 |
+
271: red wolf
|
291 |
+
272: coyote
|
292 |
+
273: dingo
|
293 |
+
274: dhole
|
294 |
+
275: African wild dog
|
295 |
+
276: hyena
|
296 |
+
277: red fox
|
297 |
+
278: kit fox
|
298 |
+
279: Arctic fox
|
299 |
+
280: grey fox
|
300 |
+
281: tabby cat
|
301 |
+
282: tiger cat
|
302 |
+
283: Persian cat
|
303 |
+
284: Siamese cat
|
304 |
+
285: Egyptian Mau
|
305 |
+
286: cougar
|
306 |
+
287: lynx
|
307 |
+
288: leopard
|
308 |
+
289: snow leopard
|
309 |
+
290: jaguar
|
310 |
+
291: lion
|
311 |
+
292: tiger
|
312 |
+
293: cheetah
|
313 |
+
294: brown bear
|
314 |
+
295: American black bear
|
315 |
+
296: polar bear
|
316 |
+
297: sloth bear
|
317 |
+
298: mongoose
|
318 |
+
299: meerkat
|
319 |
+
300: tiger beetle
|
320 |
+
301: ladybug
|
321 |
+
302: ground beetle
|
322 |
+
303: longhorn beetle
|
323 |
+
304: leaf beetle
|
324 |
+
305: dung beetle
|
325 |
+
306: rhinoceros beetle
|
326 |
+
307: weevil
|
327 |
+
308: fly
|
328 |
+
309: bee
|
329 |
+
310: ant
|
330 |
+
311: grasshopper
|
331 |
+
312: cricket
|
332 |
+
313: stick insect
|
333 |
+
314: cockroach
|
334 |
+
315: mantis
|
335 |
+
316: cicada
|
336 |
+
317: leafhopper
|
337 |
+
318: lacewing
|
338 |
+
319: dragonfly
|
339 |
+
320: damselfly
|
340 |
+
321: red admiral
|
341 |
+
322: ringlet
|
342 |
+
323: monarch butterfly
|
343 |
+
324: small white
|
344 |
+
325: sulphur butterfly
|
345 |
+
326: gossamer-winged butterfly
|
346 |
+
327: starfish
|
347 |
+
328: sea urchin
|
348 |
+
329: sea cucumber
|
349 |
+
330: cottontail rabbit
|
350 |
+
331: hare
|
351 |
+
332: Angora rabbit
|
352 |
+
333: hamster
|
353 |
+
334: porcupine
|
354 |
+
335: fox squirrel
|
355 |
+
336: marmot
|
356 |
+
337: beaver
|
357 |
+
338: guinea pig
|
358 |
+
339: common sorrel
|
359 |
+
340: zebra
|
360 |
+
341: pig
|
361 |
+
342: wild boar
|
362 |
+
343: warthog
|
363 |
+
344: hippopotamus
|
364 |
+
345: ox
|
365 |
+
346: water buffalo
|
366 |
+
347: bison
|
367 |
+
348: ram
|
368 |
+
349: bighorn sheep
|
369 |
+
350: Alpine ibex
|
370 |
+
351: hartebeest
|
371 |
+
352: impala
|
372 |
+
353: gazelle
|
373 |
+
354: dromedary
|
374 |
+
355: llama
|
375 |
+
356: weasel
|
376 |
+
357: mink
|
377 |
+
358: European polecat
|
378 |
+
359: black-footed ferret
|
379 |
+
360: otter
|
380 |
+
361: skunk
|
381 |
+
362: badger
|
382 |
+
363: armadillo
|
383 |
+
364: three-toed sloth
|
384 |
+
365: orangutan
|
385 |
+
366: gorilla
|
386 |
+
367: chimpanzee
|
387 |
+
368: gibbon
|
388 |
+
369: siamang
|
389 |
+
370: guenon
|
390 |
+
371: patas monkey
|
391 |
+
372: baboon
|
392 |
+
373: macaque
|
393 |
+
374: langur
|
394 |
+
375: black-and-white colobus
|
395 |
+
376: proboscis monkey
|
396 |
+
377: marmoset
|
397 |
+
378: white-headed capuchin
|
398 |
+
379: howler monkey
|
399 |
+
380: titi
|
400 |
+
381: Geoffroy's spider monkey
|
401 |
+
382: common squirrel monkey
|
402 |
+
383: ring-tailed lemur
|
403 |
+
384: indri
|
404 |
+
385: Asian elephant
|
405 |
+
386: African bush elephant
|
406 |
+
387: red panda
|
407 |
+
388: giant panda
|
408 |
+
389: snoek
|
409 |
+
390: eel
|
410 |
+
391: coho salmon
|
411 |
+
392: rock beauty
|
412 |
+
393: clownfish
|
413 |
+
394: sturgeon
|
414 |
+
395: garfish
|
415 |
+
396: lionfish
|
416 |
+
397: pufferfish
|
417 |
+
398: abacus
|
418 |
+
399: abaya
|
419 |
+
400: academic gown
|
420 |
+
401: accordion
|
421 |
+
402: acoustic guitar
|
422 |
+
403: aircraft carrier
|
423 |
+
404: airliner
|
424 |
+
405: airship
|
425 |
+
406: altar
|
426 |
+
407: ambulance
|
427 |
+
408: amphibious vehicle
|
428 |
+
409: analog clock
|
429 |
+
410: apiary
|
430 |
+
411: apron
|
431 |
+
412: waste container
|
432 |
+
413: assault rifle
|
433 |
+
414: backpack
|
434 |
+
415: bakery
|
435 |
+
416: balance beam
|
436 |
+
417: balloon
|
437 |
+
418: ballpoint pen
|
438 |
+
419: Band-Aid
|
439 |
+
420: banjo
|
440 |
+
421: baluster
|
441 |
+
422: barbell
|
442 |
+
423: barber chair
|
443 |
+
424: barbershop
|
444 |
+
425: barn
|
445 |
+
426: barometer
|
446 |
+
427: barrel
|
447 |
+
428: wheelbarrow
|
448 |
+
429: baseball
|
449 |
+
430: basketball
|
450 |
+
431: bassinet
|
451 |
+
432: bassoon
|
452 |
+
433: swimming cap
|
453 |
+
434: bath towel
|
454 |
+
435: bathtub
|
455 |
+
436: station wagon
|
456 |
+
437: lighthouse
|
457 |
+
438: beaker
|
458 |
+
439: military cap
|
459 |
+
440: beer bottle
|
460 |
+
441: beer glass
|
461 |
+
442: bell-cot
|
462 |
+
443: bib
|
463 |
+
444: tandem bicycle
|
464 |
+
445: bikini
|
465 |
+
446: ring binder
|
466 |
+
447: binoculars
|
467 |
+
448: birdhouse
|
468 |
+
449: boathouse
|
469 |
+
450: bobsleigh
|
470 |
+
451: bolo tie
|
471 |
+
452: poke bonnet
|
472 |
+
453: bookcase
|
473 |
+
454: bookstore
|
474 |
+
455: bottle cap
|
475 |
+
456: bow
|
476 |
+
457: bow tie
|
477 |
+
458: brass
|
478 |
+
459: bra
|
479 |
+
460: breakwater
|
480 |
+
461: breastplate
|
481 |
+
462: broom
|
482 |
+
463: bucket
|
483 |
+
464: buckle
|
484 |
+
465: bulletproof vest
|
485 |
+
466: high-speed train
|
486 |
+
467: butcher shop
|
487 |
+
468: taxicab
|
488 |
+
469: cauldron
|
489 |
+
470: candle
|
490 |
+
471: cannon
|
491 |
+
472: canoe
|
492 |
+
473: can opener
|
493 |
+
474: cardigan
|
494 |
+
475: car mirror
|
495 |
+
476: carousel
|
496 |
+
477: tool kit
|
497 |
+
478: carton
|
498 |
+
479: car wheel
|
499 |
+
480: automated teller machine
|
500 |
+
481: cassette
|
501 |
+
482: cassette player
|
502 |
+
483: castle
|
503 |
+
484: catamaran
|
504 |
+
485: CD player
|
505 |
+
486: cello
|
506 |
+
487: mobile phone
|
507 |
+
488: chain
|
508 |
+
489: chain-link fence
|
509 |
+
490: chain mail
|
510 |
+
491: chainsaw
|
511 |
+
492: chest
|
512 |
+
493: chiffonier
|
513 |
+
494: chime
|
514 |
+
495: china cabinet
|
515 |
+
496: Christmas stocking
|
516 |
+
497: church
|
517 |
+
498: movie theater
|
518 |
+
499: cleaver
|
519 |
+
500: cliff dwelling
|
520 |
+
501: cloak
|
521 |
+
502: clogs
|
522 |
+
503: cocktail shaker
|
523 |
+
504: coffee mug
|
524 |
+
505: coffeemaker
|
525 |
+
506: coil
|
526 |
+
507: combination lock
|
527 |
+
508: computer keyboard
|
528 |
+
509: confectionery store
|
529 |
+
510: container ship
|
530 |
+
511: convertible
|
531 |
+
512: corkscrew
|
532 |
+
513: cornet
|
533 |
+
514: cowboy boot
|
534 |
+
515: cowboy hat
|
535 |
+
516: cradle
|
536 |
+
517: crane (machine)
|
537 |
+
518: crash helmet
|
538 |
+
519: crate
|
539 |
+
520: infant bed
|
540 |
+
521: Crock Pot
|
541 |
+
522: croquet ball
|
542 |
+
523: crutch
|
543 |
+
524: cuirass
|
544 |
+
525: dam
|
545 |
+
526: desk
|
546 |
+
527: desktop computer
|
547 |
+
528: rotary dial telephone
|
548 |
+
529: diaper
|
549 |
+
530: digital clock
|
550 |
+
531: digital watch
|
551 |
+
532: dining table
|
552 |
+
533: dishcloth
|
553 |
+
534: dishwasher
|
554 |
+
535: disc brake
|
555 |
+
536: dock
|
556 |
+
537: dog sled
|
557 |
+
538: dome
|
558 |
+
539: doormat
|
559 |
+
540: drilling rig
|
560 |
+
541: drum
|
561 |
+
542: drumstick
|
562 |
+
543: dumbbell
|
563 |
+
544: Dutch oven
|
564 |
+
545: electric fan
|
565 |
+
546: electric guitar
|
566 |
+
547: electric locomotive
|
567 |
+
548: entertainment center
|
568 |
+
549: envelope
|
569 |
+
550: espresso machine
|
570 |
+
551: face powder
|
571 |
+
552: feather boa
|
572 |
+
553: filing cabinet
|
573 |
+
554: fireboat
|
574 |
+
555: fire engine
|
575 |
+
556: fire screen sheet
|
576 |
+
557: flagpole
|
577 |
+
558: flute
|
578 |
+
559: folding chair
|
579 |
+
560: football helmet
|
580 |
+
561: forklift
|
581 |
+
562: fountain
|
582 |
+
563: fountain pen
|
583 |
+
564: four-poster bed
|
584 |
+
565: freight car
|
585 |
+
566: French horn
|
586 |
+
567: frying pan
|
587 |
+
568: fur coat
|
588 |
+
569: garbage truck
|
589 |
+
570: gas mask
|
590 |
+
571: gas pump
|
591 |
+
572: goblet
|
592 |
+
573: go-kart
|
593 |
+
574: golf ball
|
594 |
+
575: golf cart
|
595 |
+
576: gondola
|
596 |
+
577: gong
|
597 |
+
578: gown
|
598 |
+
579: grand piano
|
599 |
+
580: greenhouse
|
600 |
+
581: grille
|
601 |
+
582: grocery store
|
602 |
+
583: guillotine
|
603 |
+
584: barrette
|
604 |
+
585: hair spray
|
605 |
+
586: half-track
|
606 |
+
587: hammer
|
607 |
+
588: hamper
|
608 |
+
589: hair dryer
|
609 |
+
590: hand-held computer
|
610 |
+
591: handkerchief
|
611 |
+
592: hard disk drive
|
612 |
+
593: harmonica
|
613 |
+
594: harp
|
614 |
+
595: harvester
|
615 |
+
596: hatchet
|
616 |
+
597: holster
|
617 |
+
598: home theater
|
618 |
+
599: honeycomb
|
619 |
+
600: hook
|
620 |
+
601: hoop skirt
|
621 |
+
602: horizontal bar
|
622 |
+
603: horse-drawn vehicle
|
623 |
+
604: hourglass
|
624 |
+
605: iPod
|
625 |
+
606: clothes iron
|
626 |
+
607: jack-o'-lantern
|
627 |
+
608: jeans
|
628 |
+
609: jeep
|
629 |
+
610: T-shirt
|
630 |
+
611: jigsaw puzzle
|
631 |
+
612: pulled rickshaw
|
632 |
+
613: joystick
|
633 |
+
614: kimono
|
634 |
+
615: knee pad
|
635 |
+
616: knot
|
636 |
+
617: lab coat
|
637 |
+
618: ladle
|
638 |
+
619: lampshade
|
639 |
+
620: laptop computer
|
640 |
+
621: lawn mower
|
641 |
+
622: lens cap
|
642 |
+
623: paper knife
|
643 |
+
624: library
|
644 |
+
625: lifeboat
|
645 |
+
626: lighter
|
646 |
+
627: limousine
|
647 |
+
628: ocean liner
|
648 |
+
629: lipstick
|
649 |
+
630: slip-on shoe
|
650 |
+
631: lotion
|
651 |
+
632: speaker
|
652 |
+
633: loupe
|
653 |
+
634: sawmill
|
654 |
+
635: magnetic compass
|
655 |
+
636: mail bag
|
656 |
+
637: mailbox
|
657 |
+
638: tights
|
658 |
+
639: tank suit
|
659 |
+
640: manhole cover
|
660 |
+
641: maraca
|
661 |
+
642: marimba
|
662 |
+
643: mask
|
663 |
+
644: match
|
664 |
+
645: maypole
|
665 |
+
646: maze
|
666 |
+
647: measuring cup
|
667 |
+
648: medicine chest
|
668 |
+
649: megalith
|
669 |
+
650: microphone
|
670 |
+
651: microwave oven
|
671 |
+
652: military uniform
|
672 |
+
653: milk can
|
673 |
+
654: minibus
|
674 |
+
655: miniskirt
|
675 |
+
656: minivan
|
676 |
+
657: missile
|
677 |
+
658: mitten
|
678 |
+
659: mixing bowl
|
679 |
+
660: mobile home
|
680 |
+
661: Model T
|
681 |
+
662: modem
|
682 |
+
663: monastery
|
683 |
+
664: monitor
|
684 |
+
665: moped
|
685 |
+
666: mortar
|
686 |
+
667: square academic cap
|
687 |
+
668: mosque
|
688 |
+
669: mosquito net
|
689 |
+
670: scooter
|
690 |
+
671: mountain bike
|
691 |
+
672: tent
|
692 |
+
673: computer mouse
|
693 |
+
674: mousetrap
|
694 |
+
675: moving van
|
695 |
+
676: muzzle
|
696 |
+
677: nail
|
697 |
+
678: neck brace
|
698 |
+
679: necklace
|
699 |
+
680: nipple
|
700 |
+
681: notebook computer
|
701 |
+
682: obelisk
|
702 |
+
683: oboe
|
703 |
+
684: ocarina
|
704 |
+
685: odometer
|
705 |
+
686: oil filter
|
706 |
+
687: organ
|
707 |
+
688: oscilloscope
|
708 |
+
689: overskirt
|
709 |
+
690: bullock cart
|
710 |
+
691: oxygen mask
|
711 |
+
692: packet
|
712 |
+
693: paddle
|
713 |
+
694: paddle wheel
|
714 |
+
695: padlock
|
715 |
+
696: paintbrush
|
716 |
+
697: pajamas
|
717 |
+
698: palace
|
718 |
+
699: pan flute
|
719 |
+
700: paper towel
|
720 |
+
701: parachute
|
721 |
+
702: parallel bars
|
722 |
+
703: park bench
|
723 |
+
704: parking meter
|
724 |
+
705: passenger car
|
725 |
+
706: patio
|
726 |
+
707: payphone
|
727 |
+
708: pedestal
|
728 |
+
709: pencil case
|
729 |
+
710: pencil sharpener
|
730 |
+
711: perfume
|
731 |
+
712: Petri dish
|
732 |
+
713: photocopier
|
733 |
+
714: plectrum
|
734 |
+
715: Pickelhaube
|
735 |
+
716: picket fence
|
736 |
+
717: pickup truck
|
737 |
+
718: pier
|
738 |
+
719: piggy bank
|
739 |
+
720: pill bottle
|
740 |
+
721: pillow
|
741 |
+
722: ping-pong ball
|
742 |
+
723: pinwheel
|
743 |
+
724: pirate ship
|
744 |
+
725: pitcher
|
745 |
+
726: hand plane
|
746 |
+
727: planetarium
|
747 |
+
728: plastic bag
|
748 |
+
729: plate rack
|
749 |
+
730: plow
|
750 |
+
731: plunger
|
751 |
+
732: Polaroid camera
|
752 |
+
733: pole
|
753 |
+
734: police van
|
754 |
+
735: poncho
|
755 |
+
736: billiard table
|
756 |
+
737: soda bottle
|
757 |
+
738: pot
|
758 |
+
739: potter's wheel
|
759 |
+
740: power drill
|
760 |
+
741: prayer rug
|
761 |
+
742: printer
|
762 |
+
743: prison
|
763 |
+
744: projectile
|
764 |
+
745: projector
|
765 |
+
746: hockey puck
|
766 |
+
747: punching bag
|
767 |
+
748: purse
|
768 |
+
749: quill
|
769 |
+
750: quilt
|
770 |
+
751: race car
|
771 |
+
752: racket
|
772 |
+
753: radiator
|
773 |
+
754: radio
|
774 |
+
755: radio telescope
|
775 |
+
756: rain barrel
|
776 |
+
757: recreational vehicle
|
777 |
+
758: reel
|
778 |
+
759: reflex camera
|
779 |
+
760: refrigerator
|
780 |
+
761: remote control
|
781 |
+
762: restaurant
|
782 |
+
763: revolver
|
783 |
+
764: rifle
|
784 |
+
765: rocking chair
|
785 |
+
766: rotisserie
|
786 |
+
767: eraser
|
787 |
+
768: rugby ball
|
788 |
+
769: ruler
|
789 |
+
770: running shoe
|
790 |
+
771: safe
|
791 |
+
772: safety pin
|
792 |
+
773: salt shaker
|
793 |
+
774: sandal
|
794 |
+
775: sarong
|
795 |
+
776: saxophone
|
796 |
+
777: scabbard
|
797 |
+
778: weighing scale
|
798 |
+
779: school bus
|
799 |
+
780: schooner
|
800 |
+
781: scoreboard
|
801 |
+
782: CRT screen
|
802 |
+
783: screw
|
803 |
+
784: screwdriver
|
804 |
+
785: seat belt
|
805 |
+
786: sewing machine
|
806 |
+
787: shield
|
807 |
+
788: shoe store
|
808 |
+
789: shoji
|
809 |
+
790: shopping basket
|
810 |
+
791: shopping cart
|
811 |
+
792: shovel
|
812 |
+
793: shower cap
|
813 |
+
794: shower curtain
|
814 |
+
795: ski
|
815 |
+
796: ski mask
|
816 |
+
797: sleeping bag
|
817 |
+
798: slide rule
|
818 |
+
799: sliding door
|
819 |
+
800: slot machine
|
820 |
+
801: snorkel
|
821 |
+
802: snowmobile
|
822 |
+
803: snowplow
|
823 |
+
804: soap dispenser
|
824 |
+
805: soccer ball
|
825 |
+
806: sock
|
826 |
+
807: solar thermal collector
|
827 |
+
808: sombrero
|
828 |
+
809: soup bowl
|
829 |
+
810: space bar
|
830 |
+
811: space heater
|
831 |
+
812: space shuttle
|
832 |
+
813: spatula
|
833 |
+
814: motorboat
|
834 |
+
815: spider web
|
835 |
+
816: spindle
|
836 |
+
817: sports car
|
837 |
+
818: spotlight
|
838 |
+
819: stage
|
839 |
+
820: steam locomotive
|
840 |
+
821: through arch bridge
|
841 |
+
822: steel drum
|
842 |
+
823: stethoscope
|
843 |
+
824: scarf
|
844 |
+
825: stone wall
|
845 |
+
826: stopwatch
|
846 |
+
827: stove
|
847 |
+
828: strainer
|
848 |
+
829: tram
|
849 |
+
830: stretcher
|
850 |
+
831: couch
|
851 |
+
832: stupa
|
852 |
+
833: submarine
|
853 |
+
834: suit
|
854 |
+
835: sundial
|
855 |
+
836: sunglass
|
856 |
+
837: sunglasses
|
857 |
+
838: sunscreen
|
858 |
+
839: suspension bridge
|
859 |
+
840: mop
|
860 |
+
841: sweatshirt
|
861 |
+
842: swimsuit
|
862 |
+
843: swing
|
863 |
+
844: switch
|
864 |
+
845: syringe
|
865 |
+
846: table lamp
|
866 |
+
847: tank
|
867 |
+
848: tape player
|
868 |
+
849: teapot
|
869 |
+
850: teddy bear
|
870 |
+
851: television
|
871 |
+
852: tennis ball
|
872 |
+
853: thatched roof
|
873 |
+
854: front curtain
|
874 |
+
855: thimble
|
875 |
+
856: threshing machine
|
876 |
+
857: throne
|
877 |
+
858: tile roof
|
878 |
+
859: toaster
|
879 |
+
860: tobacco shop
|
880 |
+
861: toilet seat
|
881 |
+
862: torch
|
882 |
+
863: totem pole
|
883 |
+
864: tow truck
|
884 |
+
865: toy store
|
885 |
+
866: tractor
|
886 |
+
867: semi-trailer truck
|
887 |
+
868: tray
|
888 |
+
869: trench coat
|
889 |
+
870: tricycle
|
890 |
+
871: trimaran
|
891 |
+
872: tripod
|
892 |
+
873: triumphal arch
|
893 |
+
874: trolleybus
|
894 |
+
875: trombone
|
895 |
+
876: tub
|
896 |
+
877: turnstile
|
897 |
+
878: typewriter keyboard
|
898 |
+
879: umbrella
|
899 |
+
880: unicycle
|
900 |
+
881: upright piano
|
901 |
+
882: vacuum cleaner
|
902 |
+
883: vase
|
903 |
+
884: vault
|
904 |
+
885: velvet
|
905 |
+
886: vending machine
|
906 |
+
887: vestment
|
907 |
+
888: viaduct
|
908 |
+
889: violin
|
909 |
+
890: volleyball
|
910 |
+
891: waffle iron
|
911 |
+
892: wall clock
|
912 |
+
893: wallet
|
913 |
+
894: wardrobe
|
914 |
+
895: military aircraft
|
915 |
+
896: sink
|
916 |
+
897: washing machine
|
917 |
+
898: water bottle
|
918 |
+
899: water jug
|
919 |
+
900: water tower
|
920 |
+
901: whiskey jug
|
921 |
+
902: whistle
|
922 |
+
903: wig
|
923 |
+
904: window screen
|
924 |
+
905: window shade
|
925 |
+
906: Windsor tie
|
926 |
+
907: wine bottle
|
927 |
+
908: wing
|
928 |
+
909: wok
|
929 |
+
910: wooden spoon
|
930 |
+
911: wool
|
931 |
+
912: split-rail fence
|
932 |
+
913: shipwreck
|
933 |
+
914: yawl
|
934 |
+
915: yurt
|
935 |
+
916: website
|
936 |
+
917: comic book
|
937 |
+
918: crossword
|
938 |
+
919: traffic sign
|
939 |
+
920: traffic light
|
940 |
+
921: dust jacket
|
941 |
+
922: menu
|
942 |
+
923: plate
|
943 |
+
924: guacamole
|
944 |
+
925: consomme
|
945 |
+
926: hot pot
|
946 |
+
927: trifle
|
947 |
+
928: ice cream
|
948 |
+
929: ice pop
|
949 |
+
930: baguette
|
950 |
+
931: bagel
|
951 |
+
932: pretzel
|
952 |
+
933: cheeseburger
|
953 |
+
934: hot dog
|
954 |
+
935: mashed potato
|
955 |
+
936: cabbage
|
956 |
+
937: broccoli
|
957 |
+
938: cauliflower
|
958 |
+
939: zucchini
|
959 |
+
940: spaghetti squash
|
960 |
+
941: acorn squash
|
961 |
+
942: butternut squash
|
962 |
+
943: cucumber
|
963 |
+
944: artichoke
|
964 |
+
945: bell pepper
|
965 |
+
946: cardoon
|
966 |
+
947: mushroom
|
967 |
+
948: Granny Smith
|
968 |
+
949: strawberry
|
969 |
+
950: orange
|
970 |
+
951: lemon
|
971 |
+
952: fig
|
972 |
+
953: pineapple
|
973 |
+
954: banana
|
974 |
+
955: jackfruit
|
975 |
+
956: custard apple
|
976 |
+
957: pomegranate
|
977 |
+
958: hay
|
978 |
+
959: carbonara
|
979 |
+
960: chocolate syrup
|
980 |
+
961: dough
|
981 |
+
962: meatloaf
|
982 |
+
963: pizza
|
983 |
+
964: pot pie
|
984 |
+
965: burrito
|
985 |
+
966: red wine
|
986 |
+
967: espresso
|
987 |
+
968: cup
|
988 |
+
969: eggnog
|
989 |
+
970: alp
|
990 |
+
971: bubble
|
991 |
+
972: cliff
|
992 |
+
973: coral reef
|
993 |
+
974: geyser
|
994 |
+
975: lakeshore
|
995 |
+
976: promontory
|
996 |
+
977: shoal
|
997 |
+
978: seashore
|
998 |
+
979: valley
|
999 |
+
980: volcano
|
1000 |
+
981: baseball player
|
1001 |
+
982: bridegroom
|
1002 |
+
983: scuba diver
|
1003 |
+
984: rapeseed
|
1004 |
+
985: daisy
|
1005 |
+
986: yellow lady's slipper
|
1006 |
+
987: corn
|
1007 |
+
988: acorn
|
1008 |
+
989: rose hip
|
1009 |
+
990: horse chestnut seed
|
1010 |
+
991: coral fungus
|
1011 |
+
992: agaric
|
1012 |
+
993: gyromitra
|
1013 |
+
994: stinkhorn mushroom
|
1014 |
+
995: earth star
|
1015 |
+
996: hen-of-the-woods
|
1016 |
+
997: bolete
|
1017 |
+
998: ear
|
1018 |
+
999: toilet paper
|
1019 |
+
|
1020 |
+
# Imagenet class codes to human-readable names
|
1021 |
+
map:
|
1022 |
+
n01440764: tench
|
1023 |
+
n01443537: goldfish
|
1024 |
+
n01484850: great_white_shark
|
1025 |
+
n01491361: tiger_shark
|
1026 |
+
n01494475: hammerhead
|
1027 |
+
n01496331: electric_ray
|
1028 |
+
n01498041: stingray
|
1029 |
+
n01514668: cock
|
1030 |
+
n01514859: hen
|
1031 |
+
n01518878: ostrich
|
1032 |
+
n01530575: brambling
|
1033 |
+
n01531178: goldfinch
|
1034 |
+
n01532829: house_finch
|
1035 |
+
n01534433: junco
|
1036 |
+
n01537544: indigo_bunting
|
1037 |
+
n01558993: robin
|
1038 |
+
n01560419: bulbul
|
1039 |
+
n01580077: jay
|
1040 |
+
n01582220: magpie
|
1041 |
+
n01592084: chickadee
|
1042 |
+
n01601694: water_ouzel
|
1043 |
+
n01608432: kite
|
1044 |
+
n01614925: bald_eagle
|
1045 |
+
n01616318: vulture
|
1046 |
+
n01622779: great_grey_owl
|
1047 |
+
n01629819: European_fire_salamander
|
1048 |
+
n01630670: common_newt
|
1049 |
+
n01631663: eft
|
1050 |
+
n01632458: spotted_salamander
|
1051 |
+
n01632777: axolotl
|
1052 |
+
n01641577: bullfrog
|
1053 |
+
n01644373: tree_frog
|
1054 |
+
n01644900: tailed_frog
|
1055 |
+
n01664065: loggerhead
|
1056 |
+
n01665541: leatherback_turtle
|
1057 |
+
n01667114: mud_turtle
|
1058 |
+
n01667778: terrapin
|
1059 |
+
n01669191: box_turtle
|
1060 |
+
n01675722: banded_gecko
|
1061 |
+
n01677366: common_iguana
|
1062 |
+
n01682714: American_chameleon
|
1063 |
+
n01685808: whiptail
|
1064 |
+
n01687978: agama
|
1065 |
+
n01688243: frilled_lizard
|
1066 |
+
n01689811: alligator_lizard
|
1067 |
+
n01692333: Gila_monster
|
1068 |
+
n01693334: green_lizard
|
1069 |
+
n01694178: African_chameleon
|
1070 |
+
n01695060: Komodo_dragon
|
1071 |
+
n01697457: African_crocodile
|
1072 |
+
n01698640: American_alligator
|
1073 |
+
n01704323: triceratops
|
1074 |
+
n01728572: thunder_snake
|
1075 |
+
n01728920: ringneck_snake
|
1076 |
+
n01729322: hognose_snake
|
1077 |
+
n01729977: green_snake
|
1078 |
+
n01734418: king_snake
|
1079 |
+
n01735189: garter_snake
|
1080 |
+
n01737021: water_snake
|
1081 |
+
n01739381: vine_snake
|
1082 |
+
n01740131: night_snake
|
1083 |
+
n01742172: boa_constrictor
|
1084 |
+
n01744401: rock_python
|
1085 |
+
n01748264: Indian_cobra
|
1086 |
+
n01749939: green_mamba
|
1087 |
+
n01751748: sea_snake
|
1088 |
+
n01753488: horned_viper
|
1089 |
+
n01755581: diamondback
|
1090 |
+
n01756291: sidewinder
|
1091 |
+
n01768244: trilobite
|
1092 |
+
n01770081: harvestman
|
1093 |
+
n01770393: scorpion
|
1094 |
+
n01773157: black_and_gold_garden_spider
|
1095 |
+
n01773549: barn_spider
|
1096 |
+
n01773797: garden_spider
|
1097 |
+
n01774384: black_widow
|
1098 |
+
n01774750: tarantula
|
1099 |
+
n01775062: wolf_spider
|
1100 |
+
n01776313: tick
|
1101 |
+
n01784675: centipede
|
1102 |
+
n01795545: black_grouse
|
1103 |
+
n01796340: ptarmigan
|
1104 |
+
n01797886: ruffed_grouse
|
1105 |
+
n01798484: prairie_chicken
|
1106 |
+
n01806143: peacock
|
1107 |
+
n01806567: quail
|
1108 |
+
n01807496: partridge
|
1109 |
+
n01817953: African_grey
|
1110 |
+
n01818515: macaw
|
1111 |
+
n01819313: sulphur-crested_cockatoo
|
1112 |
+
n01820546: lorikeet
|
1113 |
+
n01824575: coucal
|
1114 |
+
n01828970: bee_eater
|
1115 |
+
n01829413: hornbill
|
1116 |
+
n01833805: hummingbird
|
1117 |
+
n01843065: jacamar
|
1118 |
+
n01843383: toucan
|
1119 |
+
n01847000: drake
|
1120 |
+
n01855032: red-breasted_merganser
|
1121 |
+
n01855672: goose
|
1122 |
+
n01860187: black_swan
|
1123 |
+
n01871265: tusker
|
1124 |
+
n01872401: echidna
|
1125 |
+
n01873310: platypus
|
1126 |
+
n01877812: wallaby
|
1127 |
+
n01882714: koala
|
1128 |
+
n01883070: wombat
|
1129 |
+
n01910747: jellyfish
|
1130 |
+
n01914609: sea_anemone
|
1131 |
+
n01917289: brain_coral
|
1132 |
+
n01924916: flatworm
|
1133 |
+
n01930112: nematode
|
1134 |
+
n01943899: conch
|
1135 |
+
n01944390: snail
|
1136 |
+
n01945685: slug
|
1137 |
+
n01950731: sea_slug
|
1138 |
+
n01955084: chiton
|
1139 |
+
n01968897: chambered_nautilus
|
1140 |
+
n01978287: Dungeness_crab
|
1141 |
+
n01978455: rock_crab
|
1142 |
+
n01980166: fiddler_crab
|
1143 |
+
n01981276: king_crab
|
1144 |
+
n01983481: American_lobster
|
1145 |
+
n01984695: spiny_lobster
|
1146 |
+
n01985128: crayfish
|
1147 |
+
n01986214: hermit_crab
|
1148 |
+
n01990800: isopod
|
1149 |
+
n02002556: white_stork
|
1150 |
+
n02002724: black_stork
|
1151 |
+
n02006656: spoonbill
|
1152 |
+
n02007558: flamingo
|
1153 |
+
n02009229: little_blue_heron
|
1154 |
+
n02009912: American_egret
|
1155 |
+
n02011460: bittern
|
1156 |
+
n02012849: crane_(bird)
|
1157 |
+
n02013706: limpkin
|
1158 |
+
n02017213: European_gallinule
|
1159 |
+
n02018207: American_coot
|
1160 |
+
n02018795: bustard
|
1161 |
+
n02025239: ruddy_turnstone
|
1162 |
+
n02027492: red-backed_sandpiper
|
1163 |
+
n02028035: redshank
|
1164 |
+
n02033041: dowitcher
|
1165 |
+
n02037110: oystercatcher
|
1166 |
+
n02051845: pelican
|
1167 |
+
n02056570: king_penguin
|
1168 |
+
n02058221: albatross
|
1169 |
+
n02066245: grey_whale
|
1170 |
+
n02071294: killer_whale
|
1171 |
+
n02074367: dugong
|
1172 |
+
n02077923: sea_lion
|
1173 |
+
n02085620: Chihuahua
|
1174 |
+
n02085782: Japanese_spaniel
|
1175 |
+
n02085936: Maltese_dog
|
1176 |
+
n02086079: Pekinese
|
1177 |
+
n02086240: Shih-Tzu
|
1178 |
+
n02086646: Blenheim_spaniel
|
1179 |
+
n02086910: papillon
|
1180 |
+
n02087046: toy_terrier
|
1181 |
+
n02087394: Rhodesian_ridgeback
|
1182 |
+
n02088094: Afghan_hound
|
1183 |
+
n02088238: basset
|
1184 |
+
n02088364: beagle
|
1185 |
+
n02088466: bloodhound
|
1186 |
+
n02088632: bluetick
|
1187 |
+
n02089078: black-and-tan_coonhound
|
1188 |
+
n02089867: Walker_hound
|
1189 |
+
n02089973: English_foxhound
|
1190 |
+
n02090379: redbone
|
1191 |
+
n02090622: borzoi
|
1192 |
+
n02090721: Irish_wolfhound
|
1193 |
+
n02091032: Italian_greyhound
|
1194 |
+
n02091134: whippet
|
1195 |
+
n02091244: Ibizan_hound
|
1196 |
+
n02091467: Norwegian_elkhound
|
1197 |
+
n02091635: otterhound
|
1198 |
+
n02091831: Saluki
|
1199 |
+
n02092002: Scottish_deerhound
|
1200 |
+
n02092339: Weimaraner
|
1201 |
+
n02093256: Staffordshire_bullterrier
|
1202 |
+
n02093428: American_Staffordshire_terrier
|
1203 |
+
n02093647: Bedlington_terrier
|
1204 |
+
n02093754: Border_terrier
|
1205 |
+
n02093859: Kerry_blue_terrier
|
1206 |
+
n02093991: Irish_terrier
|
1207 |
+
n02094114: Norfolk_terrier
|
1208 |
+
n02094258: Norwich_terrier
|
1209 |
+
n02094433: Yorkshire_terrier
|
1210 |
+
n02095314: wire-haired_fox_terrier
|
1211 |
+
n02095570: Lakeland_terrier
|
1212 |
+
n02095889: Sealyham_terrier
|
1213 |
+
n02096051: Airedale
|
1214 |
+
n02096177: cairn
|
1215 |
+
n02096294: Australian_terrier
|
1216 |
+
n02096437: Dandie_Dinmont
|
1217 |
+
n02096585: Boston_bull
|
1218 |
+
n02097047: miniature_schnauzer
|
1219 |
+
n02097130: giant_schnauzer
|
1220 |
+
n02097209: standard_schnauzer
|
1221 |
+
n02097298: Scotch_terrier
|
1222 |
+
n02097474: Tibetan_terrier
|
1223 |
+
n02097658: silky_terrier
|
1224 |
+
n02098105: soft-coated_wheaten_terrier
|
1225 |
+
n02098286: West_Highland_white_terrier
|
1226 |
+
n02098413: Lhasa
|
1227 |
+
n02099267: flat-coated_retriever
|
1228 |
+
n02099429: curly-coated_retriever
|
1229 |
+
n02099601: golden_retriever
|
1230 |
+
n02099712: Labrador_retriever
|
1231 |
+
n02099849: Chesapeake_Bay_retriever
|
1232 |
+
n02100236: German_short-haired_pointer
|
1233 |
+
n02100583: vizsla
|
1234 |
+
n02100735: English_setter
|
1235 |
+
n02100877: Irish_setter
|
1236 |
+
n02101006: Gordon_setter
|
1237 |
+
n02101388: Brittany_spaniel
|
1238 |
+
n02101556: clumber
|
1239 |
+
n02102040: English_springer
|
1240 |
+
n02102177: Welsh_springer_spaniel
|
1241 |
+
n02102318: cocker_spaniel
|
1242 |
+
n02102480: Sussex_spaniel
|
1243 |
+
n02102973: Irish_water_spaniel
|
1244 |
+
n02104029: kuvasz
|
1245 |
+
n02104365: schipperke
|
1246 |
+
n02105056: groenendael
|
1247 |
+
n02105162: malinois
|
1248 |
+
n02105251: briard
|
1249 |
+
n02105412: kelpie
|
1250 |
+
n02105505: komondor
|
1251 |
+
n02105641: Old_English_sheepdog
|
1252 |
+
n02105855: Shetland_sheepdog
|
1253 |
+
n02106030: collie
|
1254 |
+
n02106166: Border_collie
|
1255 |
+
n02106382: Bouvier_des_Flandres
|
1256 |
+
n02106550: Rottweiler
|
1257 |
+
n02106662: German_shepherd
|
1258 |
+
n02107142: Doberman
|
1259 |
+
n02107312: miniature_pinscher
|
1260 |
+
n02107574: Greater_Swiss_Mountain_dog
|
1261 |
+
n02107683: Bernese_mountain_dog
|
1262 |
+
n02107908: Appenzeller
|
1263 |
+
n02108000: EntleBucher
|
1264 |
+
n02108089: boxer
|
1265 |
+
n02108422: bull_mastiff
|
1266 |
+
n02108551: Tibetan_mastiff
|
1267 |
+
n02108915: French_bulldog
|
1268 |
+
n02109047: Great_Dane
|
1269 |
+
n02109525: Saint_Bernard
|
1270 |
+
n02109961: Eskimo_dog
|
1271 |
+
n02110063: malamute
|
1272 |
+
n02110185: Siberian_husky
|
1273 |
+
n02110341: dalmatian
|
1274 |
+
n02110627: affenpinscher
|
1275 |
+
n02110806: basenji
|
1276 |
+
n02110958: pug
|
1277 |
+
n02111129: Leonberg
|
1278 |
+
n02111277: Newfoundland
|
1279 |
+
n02111500: Great_Pyrenees
|
1280 |
+
n02111889: Samoyed
|
1281 |
+
n02112018: Pomeranian
|
1282 |
+
n02112137: chow
|
1283 |
+
n02112350: keeshond
|
1284 |
+
n02112706: Brabancon_griffon
|
1285 |
+
n02113023: Pembroke
|
1286 |
+
n02113186: Cardigan
|
1287 |
+
n02113624: toy_poodle
|
1288 |
+
n02113712: miniature_poodle
|
1289 |
+
n02113799: standard_poodle
|
1290 |
+
n02113978: Mexican_hairless
|
1291 |
+
n02114367: timber_wolf
|
1292 |
+
n02114548: white_wolf
|
1293 |
+
n02114712: red_wolf
|
1294 |
+
n02114855: coyote
|
1295 |
+
n02115641: dingo
|
1296 |
+
n02115913: dhole
|
1297 |
+
n02116738: African_hunting_dog
|
1298 |
+
n02117135: hyena
|
1299 |
+
n02119022: red_fox
|
1300 |
+
n02119789: kit_fox
|
1301 |
+
n02120079: Arctic_fox
|
1302 |
+
n02120505: grey_fox
|
1303 |
+
n02123045: tabby
|
1304 |
+
n02123159: tiger_cat
|
1305 |
+
n02123394: Persian_cat
|
1306 |
+
n02123597: Siamese_cat
|
1307 |
+
n02124075: Egyptian_cat
|
1308 |
+
n02125311: cougar
|
1309 |
+
n02127052: lynx
|
1310 |
+
n02128385: leopard
|
1311 |
+
n02128757: snow_leopard
|
1312 |
+
n02128925: jaguar
|
1313 |
+
n02129165: lion
|
1314 |
+
n02129604: tiger
|
1315 |
+
n02130308: cheetah
|
1316 |
+
n02132136: brown_bear
|
1317 |
+
n02133161: American_black_bear
|
1318 |
+
n02134084: ice_bear
|
1319 |
+
n02134418: sloth_bear
|
1320 |
+
n02137549: mongoose
|
1321 |
+
n02138441: meerkat
|
1322 |
+
n02165105: tiger_beetle
|
1323 |
+
n02165456: ladybug
|
1324 |
+
n02167151: ground_beetle
|
1325 |
+
n02168699: long-horned_beetle
|
1326 |
+
n02169497: leaf_beetle
|
1327 |
+
n02172182: dung_beetle
|
1328 |
+
n02174001: rhinoceros_beetle
|
1329 |
+
n02177972: weevil
|
1330 |
+
n02190166: fly
|
1331 |
+
n02206856: bee
|
1332 |
+
n02219486: ant
|
1333 |
+
n02226429: grasshopper
|
1334 |
+
n02229544: cricket
|
1335 |
+
n02231487: walking_stick
|
1336 |
+
n02233338: cockroach
|
1337 |
+
n02236044: mantis
|
1338 |
+
n02256656: cicada
|
1339 |
+
n02259212: leafhopper
|
1340 |
+
n02264363: lacewing
|
1341 |
+
n02268443: dragonfly
|
1342 |
+
n02268853: damselfly
|
1343 |
+
n02276258: admiral
|
1344 |
+
n02277742: ringlet
|
1345 |
+
n02279972: monarch
|
1346 |
+
n02280649: cabbage_butterfly
|
1347 |
+
n02281406: sulphur_butterfly
|
1348 |
+
n02281787: lycaenid
|
1349 |
+
n02317335: starfish
|
1350 |
+
n02319095: sea_urchin
|
1351 |
+
n02321529: sea_cucumber
|
1352 |
+
n02325366: wood_rabbit
|
1353 |
+
n02326432: hare
|
1354 |
+
n02328150: Angora
|
1355 |
+
n02342885: hamster
|
1356 |
+
n02346627: porcupine
|
1357 |
+
n02356798: fox_squirrel
|
1358 |
+
n02361337: marmot
|
1359 |
+
n02363005: beaver
|
1360 |
+
n02364673: guinea_pig
|
1361 |
+
n02389026: sorrel
|
1362 |
+
n02391049: zebra
|
1363 |
+
n02395406: hog
|
1364 |
+
n02396427: wild_boar
|
1365 |
+
n02397096: warthog
|
1366 |
+
n02398521: hippopotamus
|
1367 |
+
n02403003: ox
|
1368 |
+
n02408429: water_buffalo
|
1369 |
+
n02410509: bison
|
1370 |
+
n02412080: ram
|
1371 |
+
n02415577: bighorn
|
1372 |
+
n02417914: ibex
|
1373 |
+
n02422106: hartebeest
|
1374 |
+
n02422699: impala
|
1375 |
+
n02423022: gazelle
|
1376 |
+
n02437312: Arabian_camel
|
1377 |
+
n02437616: llama
|
1378 |
+
n02441942: weasel
|
1379 |
+
n02442845: mink
|
1380 |
+
n02443114: polecat
|
1381 |
+
n02443484: black-footed_ferret
|
1382 |
+
n02444819: otter
|
1383 |
+
n02445715: skunk
|
1384 |
+
n02447366: badger
|
1385 |
+
n02454379: armadillo
|
1386 |
+
n02457408: three-toed_sloth
|
1387 |
+
n02480495: orangutan
|
1388 |
+
n02480855: gorilla
|
1389 |
+
n02481823: chimpanzee
|
1390 |
+
n02483362: gibbon
|
1391 |
+
n02483708: siamang
|
1392 |
+
n02484975: guenon
|
1393 |
+
n02486261: patas
|
1394 |
+
n02486410: baboon
|
1395 |
+
n02487347: macaque
|
1396 |
+
n02488291: langur
|
1397 |
+
n02488702: colobus
|
1398 |
+
n02489166: proboscis_monkey
|
1399 |
+
n02490219: marmoset
|
1400 |
+
n02492035: capuchin
|
1401 |
+
n02492660: howler_monkey
|
1402 |
+
n02493509: titi
|
1403 |
+
n02493793: spider_monkey
|
1404 |
+
n02494079: squirrel_monkey
|
1405 |
+
n02497673: Madagascar_cat
|
1406 |
+
n02500267: indri
|
1407 |
+
n02504013: Indian_elephant
|
1408 |
+
n02504458: African_elephant
|
1409 |
+
n02509815: lesser_panda
|
1410 |
+
n02510455: giant_panda
|
1411 |
+
n02514041: barracouta
|
1412 |
+
n02526121: eel
|
1413 |
+
n02536864: coho
|
1414 |
+
n02606052: rock_beauty
|
1415 |
+
n02607072: anemone_fish
|
1416 |
+
n02640242: sturgeon
|
1417 |
+
n02641379: gar
|
1418 |
+
n02643566: lionfish
|
1419 |
+
n02655020: puffer
|
1420 |
+
n02666196: abacus
|
1421 |
+
n02667093: abaya
|
1422 |
+
n02669723: academic_gown
|
1423 |
+
n02672831: accordion
|
1424 |
+
n02676566: acoustic_guitar
|
1425 |
+
n02687172: aircraft_carrier
|
1426 |
+
n02690373: airliner
|
1427 |
+
n02692877: airship
|
1428 |
+
n02699494: altar
|
1429 |
+
n02701002: ambulance
|
1430 |
+
n02704792: amphibian
|
1431 |
+
n02708093: analog_clock
|
1432 |
+
n02727426: apiary
|
1433 |
+
n02730930: apron
|
1434 |
+
n02747177: ashcan
|
1435 |
+
n02749479: assault_rifle
|
1436 |
+
n02769748: backpack
|
1437 |
+
n02776631: bakery
|
1438 |
+
n02777292: balance_beam
|
1439 |
+
n02782093: balloon
|
1440 |
+
n02783161: ballpoint
|
1441 |
+
n02786058: Band_Aid
|
1442 |
+
n02787622: banjo
|
1443 |
+
n02788148: bannister
|
1444 |
+
n02790996: barbell
|
1445 |
+
n02791124: barber_chair
|
1446 |
+
n02791270: barbershop
|
1447 |
+
n02793495: barn
|
1448 |
+
n02794156: barometer
|
1449 |
+
n02795169: barrel
|
1450 |
+
n02797295: barrow
|
1451 |
+
n02799071: baseball
|
1452 |
+
n02802426: basketball
|
1453 |
+
n02804414: bassinet
|
1454 |
+
n02804610: bassoon
|
1455 |
+
n02807133: bathing_cap
|
1456 |
+
n02808304: bath_towel
|
1457 |
+
n02808440: bathtub
|
1458 |
+
n02814533: beach_wagon
|
1459 |
+
n02814860: beacon
|
1460 |
+
n02815834: beaker
|
1461 |
+
n02817516: bearskin
|
1462 |
+
n02823428: beer_bottle
|
1463 |
+
n02823750: beer_glass
|
1464 |
+
n02825657: bell_cote
|
1465 |
+
n02834397: bib
|
1466 |
+
n02835271: bicycle-built-for-two
|
1467 |
+
n02837789: bikini
|
1468 |
+
n02840245: binder
|
1469 |
+
n02841315: binoculars
|
1470 |
+
n02843684: birdhouse
|
1471 |
+
n02859443: boathouse
|
1472 |
+
n02860847: bobsled
|
1473 |
+
n02865351: bolo_tie
|
1474 |
+
n02869837: bonnet
|
1475 |
+
n02870880: bookcase
|
1476 |
+
n02871525: bookshop
|
1477 |
+
n02877765: bottlecap
|
1478 |
+
n02879718: bow
|
1479 |
+
n02883205: bow_tie
|
1480 |
+
n02892201: brass
|
1481 |
+
n02892767: brassiere
|
1482 |
+
n02894605: breakwater
|
1483 |
+
n02895154: breastplate
|
1484 |
+
n02906734: broom
|
1485 |
+
n02909870: bucket
|
1486 |
+
n02910353: buckle
|
1487 |
+
n02916936: bulletproof_vest
|
1488 |
+
n02917067: bullet_train
|
1489 |
+
n02927161: butcher_shop
|
1490 |
+
n02930766: cab
|
1491 |
+
n02939185: caldron
|
1492 |
+
n02948072: candle
|
1493 |
+
n02950826: cannon
|
1494 |
+
n02951358: canoe
|
1495 |
+
n02951585: can_opener
|
1496 |
+
n02963159: cardigan
|
1497 |
+
n02965783: car_mirror
|
1498 |
+
n02966193: carousel
|
1499 |
+
n02966687: carpenter's_kit
|
1500 |
+
n02971356: carton
|
1501 |
+
n02974003: car_wheel
|
1502 |
+
n02977058: cash_machine
|
1503 |
+
n02978881: cassette
|
1504 |
+
n02979186: cassette_player
|
1505 |
+
n02980441: castle
|
1506 |
+
n02981792: catamaran
|
1507 |
+
n02988304: CD_player
|
1508 |
+
n02992211: cello
|
1509 |
+
n02992529: cellular_telephone
|
1510 |
+
n02999410: chain
|
1511 |
+
n03000134: chainlink_fence
|
1512 |
+
n03000247: chain_mail
|
1513 |
+
n03000684: chain_saw
|
1514 |
+
n03014705: chest
|
1515 |
+
n03016953: chiffonier
|
1516 |
+
n03017168: chime
|
1517 |
+
n03018349: china_cabinet
|
1518 |
+
n03026506: Christmas_stocking
|
1519 |
+
n03028079: church
|
1520 |
+
n03032252: cinema
|
1521 |
+
n03041632: cleaver
|
1522 |
+
n03042490: cliff_dwelling
|
1523 |
+
n03045698: cloak
|
1524 |
+
n03047690: clog
|
1525 |
+
n03062245: cocktail_shaker
|
1526 |
+
n03063599: coffee_mug
|
1527 |
+
n03063689: coffeepot
|
1528 |
+
n03065424: coil
|
1529 |
+
n03075370: combination_lock
|
1530 |
+
n03085013: computer_keyboard
|
1531 |
+
n03089624: confectionery
|
1532 |
+
n03095699: container_ship
|
1533 |
+
n03100240: convertible
|
1534 |
+
n03109150: corkscrew
|
1535 |
+
n03110669: cornet
|
1536 |
+
n03124043: cowboy_boot
|
1537 |
+
n03124170: cowboy_hat
|
1538 |
+
n03125729: cradle
|
1539 |
+
n03126707: crane_(machine)
|
1540 |
+
n03127747: crash_helmet
|
1541 |
+
n03127925: crate
|
1542 |
+
n03131574: crib
|
1543 |
+
n03133878: Crock_Pot
|
1544 |
+
n03134739: croquet_ball
|
1545 |
+
n03141823: crutch
|
1546 |
+
n03146219: cuirass
|
1547 |
+
n03160309: dam
|
1548 |
+
n03179701: desk
|
1549 |
+
n03180011: desktop_computer
|
1550 |
+
n03187595: dial_telephone
|
1551 |
+
n03188531: diaper
|
1552 |
+
n03196217: digital_clock
|
1553 |
+
n03197337: digital_watch
|
1554 |
+
n03201208: dining_table
|
1555 |
+
n03207743: dishrag
|
1556 |
+
n03207941: dishwasher
|
1557 |
+
n03208938: disk_brake
|
1558 |
+
n03216828: dock
|
1559 |
+
n03218198: dogsled
|
1560 |
+
n03220513: dome
|
1561 |
+
n03223299: doormat
|
1562 |
+
n03240683: drilling_platform
|
1563 |
+
n03249569: drum
|
1564 |
+
n03250847: drumstick
|
1565 |
+
n03255030: dumbbell
|
1566 |
+
n03259280: Dutch_oven
|
1567 |
+
n03271574: electric_fan
|
1568 |
+
n03272010: electric_guitar
|
1569 |
+
n03272562: electric_locomotive
|
1570 |
+
n03290653: entertainment_center
|
1571 |
+
n03291819: envelope
|
1572 |
+
n03297495: espresso_maker
|
1573 |
+
n03314780: face_powder
|
1574 |
+
n03325584: feather_boa
|
1575 |
+
n03337140: file
|
1576 |
+
n03344393: fireboat
|
1577 |
+
n03345487: fire_engine
|
1578 |
+
n03347037: fire_screen
|
1579 |
+
n03355925: flagpole
|
1580 |
+
n03372029: flute
|
1581 |
+
n03376595: folding_chair
|
1582 |
+
n03379051: football_helmet
|
1583 |
+
n03384352: forklift
|
1584 |
+
n03388043: fountain
|
1585 |
+
n03388183: fountain_pen
|
1586 |
+
n03388549: four-poster
|
1587 |
+
n03393912: freight_car
|
1588 |
+
n03394916: French_horn
|
1589 |
+
n03400231: frying_pan
|
1590 |
+
n03404251: fur_coat
|
1591 |
+
n03417042: garbage_truck
|
1592 |
+
n03424325: gasmask
|
1593 |
+
n03425413: gas_pump
|
1594 |
+
n03443371: goblet
|
1595 |
+
n03444034: go-kart
|
1596 |
+
n03445777: golf_ball
|
1597 |
+
n03445924: golfcart
|
1598 |
+
n03447447: gondola
|
1599 |
+
n03447721: gong
|
1600 |
+
n03450230: gown
|
1601 |
+
n03452741: grand_piano
|
1602 |
+
n03457902: greenhouse
|
1603 |
+
n03459775: grille
|
1604 |
+
n03461385: grocery_store
|
1605 |
+
n03467068: guillotine
|
1606 |
+
n03476684: hair_slide
|
1607 |
+
n03476991: hair_spray
|
1608 |
+
n03478589: half_track
|
1609 |
+
n03481172: hammer
|
1610 |
+
n03482405: hamper
|
1611 |
+
n03483316: hand_blower
|
1612 |
+
n03485407: hand-held_computer
|
1613 |
+
n03485794: handkerchief
|
1614 |
+
n03492542: hard_disc
|
1615 |
+
n03494278: harmonica
|
1616 |
+
n03495258: harp
|
1617 |
+
n03496892: harvester
|
1618 |
+
n03498962: hatchet
|
1619 |
+
n03527444: holster
|
1620 |
+
n03529860: home_theater
|
1621 |
+
n03530642: honeycomb
|
1622 |
+
n03532672: hook
|
1623 |
+
n03534580: hoopskirt
|
1624 |
+
n03535780: horizontal_bar
|
1625 |
+
n03538406: horse_cart
|
1626 |
+
n03544143: hourglass
|
1627 |
+
n03584254: iPod
|
1628 |
+
n03584829: iron
|
1629 |
+
n03590841: jack-o'-lantern
|
1630 |
+
n03594734: jean
|
1631 |
+
n03594945: jeep
|
1632 |
+
n03595614: jersey
|
1633 |
+
n03598930: jigsaw_puzzle
|
1634 |
+
n03599486: jinrikisha
|
1635 |
+
n03602883: joystick
|
1636 |
+
n03617480: kimono
|
1637 |
+
n03623198: knee_pad
|
1638 |
+
n03627232: knot
|
1639 |
+
n03630383: lab_coat
|
1640 |
+
n03633091: ladle
|
1641 |
+
n03637318: lampshade
|
1642 |
+
n03642806: laptop
|
1643 |
+
n03649909: lawn_mower
|
1644 |
+
n03657121: lens_cap
|
1645 |
+
n03658185: letter_opener
|
1646 |
+
n03661043: library
|
1647 |
+
n03662601: lifeboat
|
1648 |
+
n03666591: lighter
|
1649 |
+
n03670208: limousine
|
1650 |
+
n03673027: liner
|
1651 |
+
n03676483: lipstick
|
1652 |
+
n03680355: Loafer
|
1653 |
+
n03690938: lotion
|
1654 |
+
n03691459: loudspeaker
|
1655 |
+
n03692522: loupe
|
1656 |
+
n03697007: lumbermill
|
1657 |
+
n03706229: magnetic_compass
|
1658 |
+
n03709823: mailbag
|
1659 |
+
n03710193: mailbox
|
1660 |
+
n03710637: maillot_(tights)
|
1661 |
+
n03710721: maillot_(tank_suit)
|
1662 |
+
n03717622: manhole_cover
|
1663 |
+
n03720891: maraca
|
1664 |
+
n03721384: marimba
|
1665 |
+
n03724870: mask
|
1666 |
+
n03729826: matchstick
|
1667 |
+
n03733131: maypole
|
1668 |
+
n03733281: maze
|
1669 |
+
n03733805: measuring_cup
|
1670 |
+
n03742115: medicine_chest
|
1671 |
+
n03743016: megalith
|
1672 |
+
n03759954: microphone
|
1673 |
+
n03761084: microwave
|
1674 |
+
n03763968: military_uniform
|
1675 |
+
n03764736: milk_can
|
1676 |
+
n03769881: minibus
|
1677 |
+
n03770439: miniskirt
|
1678 |
+
n03770679: minivan
|
1679 |
+
n03773504: missile
|
1680 |
+
n03775071: mitten
|
1681 |
+
n03775546: mixing_bowl
|
1682 |
+
n03776460: mobile_home
|
1683 |
+
n03777568: Model_T
|
1684 |
+
n03777754: modem
|
1685 |
+
n03781244: monastery
|
1686 |
+
n03782006: monitor
|
1687 |
+
n03785016: moped
|
1688 |
+
n03786901: mortar
|
1689 |
+
n03787032: mortarboard
|
1690 |
+
n03788195: mosque
|
1691 |
+
n03788365: mosquito_net
|
1692 |
+
n03791053: motor_scooter
|
1693 |
+
n03792782: mountain_bike
|
1694 |
+
n03792972: mountain_tent
|
1695 |
+
n03793489: mouse
|
1696 |
+
n03794056: mousetrap
|
1697 |
+
n03796401: moving_van
|
1698 |
+
n03803284: muzzle
|
1699 |
+
n03804744: nail
|
1700 |
+
n03814639: neck_brace
|
1701 |
+
n03814906: necklace
|
1702 |
+
n03825788: nipple
|
1703 |
+
n03832673: notebook
|
1704 |
+
n03837869: obelisk
|
1705 |
+
n03838899: oboe
|
1706 |
+
n03840681: ocarina
|
1707 |
+
n03841143: odometer
|
1708 |
+
n03843555: oil_filter
|
1709 |
+
n03854065: organ
|
1710 |
+
n03857828: oscilloscope
|
1711 |
+
n03866082: overskirt
|
1712 |
+
n03868242: oxcart
|
1713 |
+
n03868863: oxygen_mask
|
1714 |
+
n03871628: packet
|
1715 |
+
n03873416: paddle
|
1716 |
+
n03874293: paddlewheel
|
1717 |
+
n03874599: padlock
|
1718 |
+
n03876231: paintbrush
|
1719 |
+
n03877472: pajama
|
1720 |
+
n03877845: palace
|
1721 |
+
n03884397: panpipe
|
1722 |
+
n03887697: paper_towel
|
1723 |
+
n03888257: parachute
|
1724 |
+
n03888605: parallel_bars
|
1725 |
+
n03891251: park_bench
|
1726 |
+
n03891332: parking_meter
|
1727 |
+
n03895866: passenger_car
|
1728 |
+
n03899768: patio
|
1729 |
+
n03902125: pay-phone
|
1730 |
+
n03903868: pedestal
|
1731 |
+
n03908618: pencil_box
|
1732 |
+
n03908714: pencil_sharpener
|
1733 |
+
n03916031: perfume
|
1734 |
+
n03920288: Petri_dish
|
1735 |
+
n03924679: photocopier
|
1736 |
+
n03929660: pick
|
1737 |
+
n03929855: pickelhaube
|
1738 |
+
n03930313: picket_fence
|
1739 |
+
n03930630: pickup
|
1740 |
+
n03933933: pier
|
1741 |
+
n03935335: piggy_bank
|
1742 |
+
n03937543: pill_bottle
|
1743 |
+
n03938244: pillow
|
1744 |
+
n03942813: ping-pong_ball
|
1745 |
+
n03944341: pinwheel
|
1746 |
+
n03947888: pirate
|
1747 |
+
n03950228: pitcher
|
1748 |
+
n03954731: plane
|
1749 |
+
n03956157: planetarium
|
1750 |
+
n03958227: plastic_bag
|
1751 |
+
n03961711: plate_rack
|
1752 |
+
n03967562: plow
|
1753 |
+
n03970156: plunger
|
1754 |
+
n03976467: Polaroid_camera
|
1755 |
+
n03976657: pole
|
1756 |
+
n03977966: police_van
|
1757 |
+
n03980874: poncho
|
1758 |
+
n03982430: pool_table
|
1759 |
+
n03983396: pop_bottle
|
1760 |
+
n03991062: pot
|
1761 |
+
n03992509: potter's_wheel
|
1762 |
+
n03995372: power_drill
|
1763 |
+
n03998194: prayer_rug
|
1764 |
+
n04004767: printer
|
1765 |
+
n04005630: prison
|
1766 |
+
n04008634: projectile
|
1767 |
+
n04009552: projector
|
1768 |
+
n04019541: puck
|
1769 |
+
n04023962: punching_bag
|
1770 |
+
n04026417: purse
|
1771 |
+
n04033901: quill
|
1772 |
+
n04033995: quilt
|
1773 |
+
n04037443: racer
|
1774 |
+
n04039381: racket
|
1775 |
+
n04040759: radiator
|
1776 |
+
n04041544: radio
|
1777 |
+
n04044716: radio_telescope
|
1778 |
+
n04049303: rain_barrel
|
1779 |
+
n04065272: recreational_vehicle
|
1780 |
+
n04067472: reel
|
1781 |
+
n04069434: reflex_camera
|
1782 |
+
n04070727: refrigerator
|
1783 |
+
n04074963: remote_control
|
1784 |
+
n04081281: restaurant
|
1785 |
+
n04086273: revolver
|
1786 |
+
n04090263: rifle
|
1787 |
+
n04099969: rocking_chair
|
1788 |
+
n04111531: rotisserie
|
1789 |
+
n04116512: rubber_eraser
|
1790 |
+
n04118538: rugby_ball
|
1791 |
+
n04118776: rule
|
1792 |
+
n04120489: running_shoe
|
1793 |
+
n04125021: safe
|
1794 |
+
n04127249: safety_pin
|
1795 |
+
n04131690: saltshaker
|
1796 |
+
n04133789: sandal
|
1797 |
+
n04136333: sarong
|
1798 |
+
n04141076: sax
|
1799 |
+
n04141327: scabbard
|
1800 |
+
n04141975: scale
|
1801 |
+
n04146614: school_bus
|
1802 |
+
n04147183: schooner
|
1803 |
+
n04149813: scoreboard
|
1804 |
+
n04152593: screen
|
1805 |
+
n04153751: screw
|
1806 |
+
n04154565: screwdriver
|
1807 |
+
n04162706: seat_belt
|
1808 |
+
n04179913: sewing_machine
|
1809 |
+
n04192698: shield
|
1810 |
+
n04200800: shoe_shop
|
1811 |
+
n04201297: shoji
|
1812 |
+
n04204238: shopping_basket
|
1813 |
+
n04204347: shopping_cart
|
1814 |
+
n04208210: shovel
|
1815 |
+
n04209133: shower_cap
|
1816 |
+
n04209239: shower_curtain
|
1817 |
+
n04228054: ski
|
1818 |
+
n04229816: ski_mask
|
1819 |
+
n04235860: sleeping_bag
|
1820 |
+
n04238763: slide_rule
|
1821 |
+
n04239074: sliding_door
|
1822 |
+
n04243546: slot
|
1823 |
+
n04251144: snorkel
|
1824 |
+
n04252077: snowmobile
|
1825 |
+
n04252225: snowplow
|
1826 |
+
n04254120: soap_dispenser
|
1827 |
+
n04254680: soccer_ball
|
1828 |
+
n04254777: sock
|
1829 |
+
n04258138: solar_dish
|
1830 |
+
n04259630: sombrero
|
1831 |
+
n04263257: soup_bowl
|
1832 |
+
n04264628: space_bar
|
1833 |
+
n04265275: space_heater
|
1834 |
+
n04266014: space_shuttle
|
1835 |
+
n04270147: spatula
|
1836 |
+
n04273569: speedboat
|
1837 |
+
n04275548: spider_web
|
1838 |
+
n04277352: spindle
|
1839 |
+
n04285008: sports_car
|
1840 |
+
n04286575: spotlight
|
1841 |
+
n04296562: stage
|
1842 |
+
n04310018: steam_locomotive
|
1843 |
+
n04311004: steel_arch_bridge
|
1844 |
+
n04311174: steel_drum
|
1845 |
+
n04317175: stethoscope
|
1846 |
+
n04325704: stole
|
1847 |
+
n04326547: stone_wall
|
1848 |
+
n04328186: stopwatch
|
1849 |
+
n04330267: stove
|
1850 |
+
n04332243: strainer
|
1851 |
+
n04335435: streetcar
|
1852 |
+
n04336792: stretcher
|
1853 |
+
n04344873: studio_couch
|
1854 |
+
n04346328: stupa
|
1855 |
+
n04347754: submarine
|
1856 |
+
n04350905: suit
|
1857 |
+
n04355338: sundial
|
1858 |
+
n04355933: sunglass
|
1859 |
+
n04356056: sunglasses
|
1860 |
+
n04357314: sunscreen
|
1861 |
+
n04366367: suspension_bridge
|
1862 |
+
n04367480: swab
|
1863 |
+
n04370456: sweatshirt
|
1864 |
+
n04371430: swimming_trunks
|
1865 |
+
n04371774: swing
|
1866 |
+
n04372370: switch
|
1867 |
+
n04376876: syringe
|
1868 |
+
n04380533: table_lamp
|
1869 |
+
n04389033: tank
|
1870 |
+
n04392985: tape_player
|
1871 |
+
n04398044: teapot
|
1872 |
+
n04399382: teddy
|
1873 |
+
n04404412: television
|
1874 |
+
n04409515: tennis_ball
|
1875 |
+
n04417672: thatch
|
1876 |
+
n04418357: theater_curtain
|
1877 |
+
n04423845: thimble
|
1878 |
+
n04428191: thresher
|
1879 |
+
n04429376: throne
|
1880 |
+
n04435653: tile_roof
|
1881 |
+
n04442312: toaster
|
1882 |
+
n04443257: tobacco_shop
|
1883 |
+
n04447861: toilet_seat
|
1884 |
+
n04456115: torch
|
1885 |
+
n04458633: totem_pole
|
1886 |
+
n04461696: tow_truck
|
1887 |
+
n04462240: toyshop
|
1888 |
+
n04465501: tractor
|
1889 |
+
n04467665: trailer_truck
|
1890 |
+
n04476259: tray
|
1891 |
+
n04479046: trench_coat
|
1892 |
+
n04482393: tricycle
|
1893 |
+
n04483307: trimaran
|
1894 |
+
n04485082: tripod
|
1895 |
+
n04486054: triumphal_arch
|
1896 |
+
n04487081: trolleybus
|
1897 |
+
n04487394: trombone
|
1898 |
+
n04493381: tub
|
1899 |
+
n04501370: turnstile
|
1900 |
+
n04505470: typewriter_keyboard
|
1901 |
+
n04507155: umbrella
|
1902 |
+
n04509417: unicycle
|
1903 |
+
n04515003: upright
|
1904 |
+
n04517823: vacuum
|
1905 |
+
n04522168: vase
|
1906 |
+
n04523525: vault
|
1907 |
+
n04525038: velvet
|
1908 |
+
n04525305: vending_machine
|
1909 |
+
n04532106: vestment
|
1910 |
+
n04532670: viaduct
|
1911 |
+
n04536866: violin
|
1912 |
+
n04540053: volleyball
|
1913 |
+
n04542943: waffle_iron
|
1914 |
+
n04548280: wall_clock
|
1915 |
+
n04548362: wallet
|
1916 |
+
n04550184: wardrobe
|
1917 |
+
n04552348: warplane
|
1918 |
+
n04553703: washbasin
|
1919 |
+
n04554684: washer
|
1920 |
+
n04557648: water_bottle
|
1921 |
+
n04560804: water_jug
|
1922 |
+
n04562935: water_tower
|
1923 |
+
n04579145: whiskey_jug
|
1924 |
+
n04579432: whistle
|
1925 |
+
n04584207: wig
|
1926 |
+
n04589890: window_screen
|
1927 |
+
n04590129: window_shade
|
1928 |
+
n04591157: Windsor_tie
|
1929 |
+
n04591713: wine_bottle
|
1930 |
+
n04592741: wing
|
1931 |
+
n04596742: wok
|
1932 |
+
n04597913: wooden_spoon
|
1933 |
+
n04599235: wool
|
1934 |
+
n04604644: worm_fence
|
1935 |
+
n04606251: wreck
|
1936 |
+
n04612504: yawl
|
1937 |
+
n04613696: yurt
|
1938 |
+
n06359193: web_site
|
1939 |
+
n06596364: comic_book
|
1940 |
+
n06785654: crossword_puzzle
|
1941 |
+
n06794110: street_sign
|
1942 |
+
n06874185: traffic_light
|
1943 |
+
n07248320: book_jacket
|
1944 |
+
n07565083: menu
|
1945 |
+
n07579787: plate
|
1946 |
+
n07583066: guacamole
|
1947 |
+
n07584110: consomme
|
1948 |
+
n07590611: hot_pot
|
1949 |
+
n07613480: trifle
|
1950 |
+
n07614500: ice_cream
|
1951 |
+
n07615774: ice_lolly
|
1952 |
+
n07684084: French_loaf
|
1953 |
+
n07693725: bagel
|
1954 |
+
n07695742: pretzel
|
1955 |
+
n07697313: cheeseburger
|
1956 |
+
n07697537: hotdog
|
1957 |
+
n07711569: mashed_potato
|
1958 |
+
n07714571: head_cabbage
|
1959 |
+
n07714990: broccoli
|
1960 |
+
n07715103: cauliflower
|
1961 |
+
n07716358: zucchini
|
1962 |
+
n07716906: spaghetti_squash
|
1963 |
+
n07717410: acorn_squash
|
1964 |
+
n07717556: butternut_squash
|
1965 |
+
n07718472: cucumber
|
1966 |
+
n07718747: artichoke
|
1967 |
+
n07720875: bell_pepper
|
1968 |
+
n07730033: cardoon
|
1969 |
+
n07734744: mushroom
|
1970 |
+
n07742313: Granny_Smith
|
1971 |
+
n07745940: strawberry
|
1972 |
+
n07747607: orange
|
1973 |
+
n07749582: lemon
|
1974 |
+
n07753113: fig
|
1975 |
+
n07753275: pineapple
|
1976 |
+
n07753592: banana
|
1977 |
+
n07754684: jackfruit
|
1978 |
+
n07760859: custard_apple
|
1979 |
+
n07768694: pomegranate
|
1980 |
+
n07802026: hay
|
1981 |
+
n07831146: carbonara
|
1982 |
+
n07836838: chocolate_sauce
|
1983 |
+
n07860988: dough
|
1984 |
+
n07871810: meat_loaf
|
1985 |
+
n07873807: pizza
|
1986 |
+
n07875152: potpie
|
1987 |
+
n07880968: burrito
|
1988 |
+
n07892512: red_wine
|
1989 |
+
n07920052: espresso
|
1990 |
+
n07930864: cup
|
1991 |
+
n07932039: eggnog
|
1992 |
+
n09193705: alp
|
1993 |
+
n09229709: bubble
|
1994 |
+
n09246464: cliff
|
1995 |
+
n09256479: coral_reef
|
1996 |
+
n09288635: geyser
|
1997 |
+
n09332890: lakeside
|
1998 |
+
n09399592: promontory
|
1999 |
+
n09421951: sandbar
|
2000 |
+
n09428293: seashore
|
2001 |
+
n09468604: valley
|
2002 |
+
n09472597: volcano
|
2003 |
+
n09835506: ballplayer
|
2004 |
+
n10148035: groom
|
2005 |
+
n10565667: scuba_diver
|
2006 |
+
n11879895: rapeseed
|
2007 |
+
n11939491: daisy
|
2008 |
+
n12057211: yellow_lady's_slipper
|
2009 |
+
n12144580: corn
|
2010 |
+
n12267677: acorn
|
2011 |
+
n12620546: hip
|
2012 |
+
n12768682: buckeye
|
2013 |
+
n12985857: coral_fungus
|
2014 |
+
n12998815: agaric
|
2015 |
+
n13037406: gyromitra
|
2016 |
+
n13040303: stinkhorn
|
2017 |
+
n13044778: earthstar
|
2018 |
+
n13052670: hen-of-the-woods
|
2019 |
+
n13054560: bolete
|
2020 |
+
n13133613: ear
|
2021 |
+
n15075141: toilet_tissue
|
2022 |
+
|
2023 |
+
|
2024 |
+
# Download script/URL (optional)
|
2025 |
+
download: yolo/data/scripts/get_imagenet.sh
|
ultralytics/datasets/Objects365.yaml
ADDED
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# Objects365 dataset https://www.objects365.org/ by Megvii
|
3 |
+
# Example usage: yolo train data=Objects365.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── Objects365 ← downloads here (712 GB = 367G data + 345G zips)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/Objects365 # dataset root dir
|
12 |
+
train: images/train # train images (relative to 'path') 1742289 images
|
13 |
+
val: images/val # val images (relative to 'path') 80000 images
|
14 |
+
test: # test images (optional)
|
15 |
+
|
16 |
+
# Classes
|
17 |
+
names:
|
18 |
+
0: Person
|
19 |
+
1: Sneakers
|
20 |
+
2: Chair
|
21 |
+
3: Other Shoes
|
22 |
+
4: Hat
|
23 |
+
5: Car
|
24 |
+
6: Lamp
|
25 |
+
7: Glasses
|
26 |
+
8: Bottle
|
27 |
+
9: Desk
|
28 |
+
10: Cup
|
29 |
+
11: Street Lights
|
30 |
+
12: Cabinet/shelf
|
31 |
+
13: Handbag/Satchel
|
32 |
+
14: Bracelet
|
33 |
+
15: Plate
|
34 |
+
16: Picture/Frame
|
35 |
+
17: Helmet
|
36 |
+
18: Book
|
37 |
+
19: Gloves
|
38 |
+
20: Storage box
|
39 |
+
21: Boat
|
40 |
+
22: Leather Shoes
|
41 |
+
23: Flower
|
42 |
+
24: Bench
|
43 |
+
25: Potted Plant
|
44 |
+
26: Bowl/Basin
|
45 |
+
27: Flag
|
46 |
+
28: Pillow
|
47 |
+
29: Boots
|
48 |
+
30: Vase
|
49 |
+
31: Microphone
|
50 |
+
32: Necklace
|
51 |
+
33: Ring
|
52 |
+
34: SUV
|
53 |
+
35: Wine Glass
|
54 |
+
36: Belt
|
55 |
+
37: Monitor/TV
|
56 |
+
38: Backpack
|
57 |
+
39: Umbrella
|
58 |
+
40: Traffic Light
|
59 |
+
41: Speaker
|
60 |
+
42: Watch
|
61 |
+
43: Tie
|
62 |
+
44: Trash bin Can
|
63 |
+
45: Slippers
|
64 |
+
46: Bicycle
|
65 |
+
47: Stool
|
66 |
+
48: Barrel/bucket
|
67 |
+
49: Van
|
68 |
+
50: Couch
|
69 |
+
51: Sandals
|
70 |
+
52: Basket
|
71 |
+
53: Drum
|
72 |
+
54: Pen/Pencil
|
73 |
+
55: Bus
|
74 |
+
56: Wild Bird
|
75 |
+
57: High Heels
|
76 |
+
58: Motorcycle
|
77 |
+
59: Guitar
|
78 |
+
60: Carpet
|
79 |
+
61: Cell Phone
|
80 |
+
62: Bread
|
81 |
+
63: Camera
|
82 |
+
64: Canned
|
83 |
+
65: Truck
|
84 |
+
66: Traffic cone
|
85 |
+
67: Cymbal
|
86 |
+
68: Lifesaver
|
87 |
+
69: Towel
|
88 |
+
70: Stuffed Toy
|
89 |
+
71: Candle
|
90 |
+
72: Sailboat
|
91 |
+
73: Laptop
|
92 |
+
74: Awning
|
93 |
+
75: Bed
|
94 |
+
76: Faucet
|
95 |
+
77: Tent
|
96 |
+
78: Horse
|
97 |
+
79: Mirror
|
98 |
+
80: Power outlet
|
99 |
+
81: Sink
|
100 |
+
82: Apple
|
101 |
+
83: Air Conditioner
|
102 |
+
84: Knife
|
103 |
+
85: Hockey Stick
|
104 |
+
86: Paddle
|
105 |
+
87: Pickup Truck
|
106 |
+
88: Fork
|
107 |
+
89: Traffic Sign
|
108 |
+
90: Balloon
|
109 |
+
91: Tripod
|
110 |
+
92: Dog
|
111 |
+
93: Spoon
|
112 |
+
94: Clock
|
113 |
+
95: Pot
|
114 |
+
96: Cow
|
115 |
+
97: Cake
|
116 |
+
98: Dinning Table
|
117 |
+
99: Sheep
|
118 |
+
100: Hanger
|
119 |
+
101: Blackboard/Whiteboard
|
120 |
+
102: Napkin
|
121 |
+
103: Other Fish
|
122 |
+
104: Orange/Tangerine
|
123 |
+
105: Toiletry
|
124 |
+
106: Keyboard
|
125 |
+
107: Tomato
|
126 |
+
108: Lantern
|
127 |
+
109: Machinery Vehicle
|
128 |
+
110: Fan
|
129 |
+
111: Green Vegetables
|
130 |
+
112: Banana
|
131 |
+
113: Baseball Glove
|
132 |
+
114: Airplane
|
133 |
+
115: Mouse
|
134 |
+
116: Train
|
135 |
+
117: Pumpkin
|
136 |
+
118: Soccer
|
137 |
+
119: Skiboard
|
138 |
+
120: Luggage
|
139 |
+
121: Nightstand
|
140 |
+
122: Tea pot
|
141 |
+
123: Telephone
|
142 |
+
124: Trolley
|
143 |
+
125: Head Phone
|
144 |
+
126: Sports Car
|
145 |
+
127: Stop Sign
|
146 |
+
128: Dessert
|
147 |
+
129: Scooter
|
148 |
+
130: Stroller
|
149 |
+
131: Crane
|
150 |
+
132: Remote
|
151 |
+
133: Refrigerator
|
152 |
+
134: Oven
|
153 |
+
135: Lemon
|
154 |
+
136: Duck
|
155 |
+
137: Baseball Bat
|
156 |
+
138: Surveillance Camera
|
157 |
+
139: Cat
|
158 |
+
140: Jug
|
159 |
+
141: Broccoli
|
160 |
+
142: Piano
|
161 |
+
143: Pizza
|
162 |
+
144: Elephant
|
163 |
+
145: Skateboard
|
164 |
+
146: Surfboard
|
165 |
+
147: Gun
|
166 |
+
148: Skating and Skiing shoes
|
167 |
+
149: Gas stove
|
168 |
+
150: Donut
|
169 |
+
151: Bow Tie
|
170 |
+
152: Carrot
|
171 |
+
153: Toilet
|
172 |
+
154: Kite
|
173 |
+
155: Strawberry
|
174 |
+
156: Other Balls
|
175 |
+
157: Shovel
|
176 |
+
158: Pepper
|
177 |
+
159: Computer Box
|
178 |
+
160: Toilet Paper
|
179 |
+
161: Cleaning Products
|
180 |
+
162: Chopsticks
|
181 |
+
163: Microwave
|
182 |
+
164: Pigeon
|
183 |
+
165: Baseball
|
184 |
+
166: Cutting/chopping Board
|
185 |
+
167: Coffee Table
|
186 |
+
168: Side Table
|
187 |
+
169: Scissors
|
188 |
+
170: Marker
|
189 |
+
171: Pie
|
190 |
+
172: Ladder
|
191 |
+
173: Snowboard
|
192 |
+
174: Cookies
|
193 |
+
175: Radiator
|
194 |
+
176: Fire Hydrant
|
195 |
+
177: Basketball
|
196 |
+
178: Zebra
|
197 |
+
179: Grape
|
198 |
+
180: Giraffe
|
199 |
+
181: Potato
|
200 |
+
182: Sausage
|
201 |
+
183: Tricycle
|
202 |
+
184: Violin
|
203 |
+
185: Egg
|
204 |
+
186: Fire Extinguisher
|
205 |
+
187: Candy
|
206 |
+
188: Fire Truck
|
207 |
+
189: Billiards
|
208 |
+
190: Converter
|
209 |
+
191: Bathtub
|
210 |
+
192: Wheelchair
|
211 |
+
193: Golf Club
|
212 |
+
194: Briefcase
|
213 |
+
195: Cucumber
|
214 |
+
196: Cigar/Cigarette
|
215 |
+
197: Paint Brush
|
216 |
+
198: Pear
|
217 |
+
199: Heavy Truck
|
218 |
+
200: Hamburger
|
219 |
+
201: Extractor
|
220 |
+
202: Extension Cord
|
221 |
+
203: Tong
|
222 |
+
204: Tennis Racket
|
223 |
+
205: Folder
|
224 |
+
206: American Football
|
225 |
+
207: earphone
|
226 |
+
208: Mask
|
227 |
+
209: Kettle
|
228 |
+
210: Tennis
|
229 |
+
211: Ship
|
230 |
+
212: Swing
|
231 |
+
213: Coffee Machine
|
232 |
+
214: Slide
|
233 |
+
215: Carriage
|
234 |
+
216: Onion
|
235 |
+
217: Green beans
|
236 |
+
218: Projector
|
237 |
+
219: Frisbee
|
238 |
+
220: Washing Machine/Drying Machine
|
239 |
+
221: Chicken
|
240 |
+
222: Printer
|
241 |
+
223: Watermelon
|
242 |
+
224: Saxophone
|
243 |
+
225: Tissue
|
244 |
+
226: Toothbrush
|
245 |
+
227: Ice cream
|
246 |
+
228: Hot-air balloon
|
247 |
+
229: Cello
|
248 |
+
230: French Fries
|
249 |
+
231: Scale
|
250 |
+
232: Trophy
|
251 |
+
233: Cabbage
|
252 |
+
234: Hot dog
|
253 |
+
235: Blender
|
254 |
+
236: Peach
|
255 |
+
237: Rice
|
256 |
+
238: Wallet/Purse
|
257 |
+
239: Volleyball
|
258 |
+
240: Deer
|
259 |
+
241: Goose
|
260 |
+
242: Tape
|
261 |
+
243: Tablet
|
262 |
+
244: Cosmetics
|
263 |
+
245: Trumpet
|
264 |
+
246: Pineapple
|
265 |
+
247: Golf Ball
|
266 |
+
248: Ambulance
|
267 |
+
249: Parking meter
|
268 |
+
250: Mango
|
269 |
+
251: Key
|
270 |
+
252: Hurdle
|
271 |
+
253: Fishing Rod
|
272 |
+
254: Medal
|
273 |
+
255: Flute
|
274 |
+
256: Brush
|
275 |
+
257: Penguin
|
276 |
+
258: Megaphone
|
277 |
+
259: Corn
|
278 |
+
260: Lettuce
|
279 |
+
261: Garlic
|
280 |
+
262: Swan
|
281 |
+
263: Helicopter
|
282 |
+
264: Green Onion
|
283 |
+
265: Sandwich
|
284 |
+
266: Nuts
|
285 |
+
267: Speed Limit Sign
|
286 |
+
268: Induction Cooker
|
287 |
+
269: Broom
|
288 |
+
270: Trombone
|
289 |
+
271: Plum
|
290 |
+
272: Rickshaw
|
291 |
+
273: Goldfish
|
292 |
+
274: Kiwi fruit
|
293 |
+
275: Router/modem
|
294 |
+
276: Poker Card
|
295 |
+
277: Toaster
|
296 |
+
278: Shrimp
|
297 |
+
279: Sushi
|
298 |
+
280: Cheese
|
299 |
+
281: Notepaper
|
300 |
+
282: Cherry
|
301 |
+
283: Pliers
|
302 |
+
284: CD
|
303 |
+
285: Pasta
|
304 |
+
286: Hammer
|
305 |
+
287: Cue
|
306 |
+
288: Avocado
|
307 |
+
289: Hamimelon
|
308 |
+
290: Flask
|
309 |
+
291: Mushroom
|
310 |
+
292: Screwdriver
|
311 |
+
293: Soap
|
312 |
+
294: Recorder
|
313 |
+
295: Bear
|
314 |
+
296: Eggplant
|
315 |
+
297: Board Eraser
|
316 |
+
298: Coconut
|
317 |
+
299: Tape Measure/Ruler
|
318 |
+
300: Pig
|
319 |
+
301: Showerhead
|
320 |
+
302: Globe
|
321 |
+
303: Chips
|
322 |
+
304: Steak
|
323 |
+
305: Crosswalk Sign
|
324 |
+
306: Stapler
|
325 |
+
307: Camel
|
326 |
+
308: Formula 1
|
327 |
+
309: Pomegranate
|
328 |
+
310: Dishwasher
|
329 |
+
311: Crab
|
330 |
+
312: Hoverboard
|
331 |
+
313: Meat ball
|
332 |
+
314: Rice Cooker
|
333 |
+
315: Tuba
|
334 |
+
316: Calculator
|
335 |
+
317: Papaya
|
336 |
+
318: Antelope
|
337 |
+
319: Parrot
|
338 |
+
320: Seal
|
339 |
+
321: Butterfly
|
340 |
+
322: Dumbbell
|
341 |
+
323: Donkey
|
342 |
+
324: Lion
|
343 |
+
325: Urinal
|
344 |
+
326: Dolphin
|
345 |
+
327: Electric Drill
|
346 |
+
328: Hair Dryer
|
347 |
+
329: Egg tart
|
348 |
+
330: Jellyfish
|
349 |
+
331: Treadmill
|
350 |
+
332: Lighter
|
351 |
+
333: Grapefruit
|
352 |
+
334: Game board
|
353 |
+
335: Mop
|
354 |
+
336: Radish
|
355 |
+
337: Baozi
|
356 |
+
338: Target
|
357 |
+
339: French
|
358 |
+
340: Spring Rolls
|
359 |
+
341: Monkey
|
360 |
+
342: Rabbit
|
361 |
+
343: Pencil Case
|
362 |
+
344: Yak
|
363 |
+
345: Red Cabbage
|
364 |
+
346: Binoculars
|
365 |
+
347: Asparagus
|
366 |
+
348: Barbell
|
367 |
+
349: Scallop
|
368 |
+
350: Noddles
|
369 |
+
351: Comb
|
370 |
+
352: Dumpling
|
371 |
+
353: Oyster
|
372 |
+
354: Table Tennis paddle
|
373 |
+
355: Cosmetics Brush/Eyeliner Pencil
|
374 |
+
356: Chainsaw
|
375 |
+
357: Eraser
|
376 |
+
358: Lobster
|
377 |
+
359: Durian
|
378 |
+
360: Okra
|
379 |
+
361: Lipstick
|
380 |
+
362: Cosmetics Mirror
|
381 |
+
363: Curling
|
382 |
+
364: Table Tennis
|
383 |
+
|
384 |
+
|
385 |
+
# Download script/URL (optional) ---------------------------------------------------------------------------------------
|
386 |
+
download: |
|
387 |
+
from tqdm import tqdm
|
388 |
+
|
389 |
+
from ultralytics.yolo.utils.checks import check_requirements
|
390 |
+
from ultralytics.yolo.utils.downloads import download
|
391 |
+
from ultralytics.yolo.utils.ops import xyxy2xywhn
|
392 |
+
|
393 |
+
import numpy as np
|
394 |
+
from pathlib import Path
|
395 |
+
|
396 |
+
check_requirements(('pycocotools>=2.0',))
|
397 |
+
from pycocotools.coco import COCO
|
398 |
+
|
399 |
+
# Make Directories
|
400 |
+
dir = Path(yaml['path']) # dataset root dir
|
401 |
+
for p in 'images', 'labels':
|
402 |
+
(dir / p).mkdir(parents=True, exist_ok=True)
|
403 |
+
for q in 'train', 'val':
|
404 |
+
(dir / p / q).mkdir(parents=True, exist_ok=True)
|
405 |
+
|
406 |
+
# Train, Val Splits
|
407 |
+
for split, patches in [('train', 50 + 1), ('val', 43 + 1)]:
|
408 |
+
print(f"Processing {split} in {patches} patches ...")
|
409 |
+
images, labels = dir / 'images' / split, dir / 'labels' / split
|
410 |
+
|
411 |
+
# Download
|
412 |
+
url = f"https://dorc.ks3-cn-beijing.ksyun.com/data-set/2020Objects365%E6%95%B0%E6%8D%AE%E9%9B%86/{split}/"
|
413 |
+
if split == 'train':
|
414 |
+
download([f'{url}zhiyuan_objv2_{split}.tar.gz'], dir=dir) # annotations json
|
415 |
+
download([f'{url}patch{i}.tar.gz' for i in range(patches)], dir=images, curl=True, threads=8)
|
416 |
+
elif split == 'val':
|
417 |
+
download([f'{url}zhiyuan_objv2_{split}.json'], dir=dir) # annotations json
|
418 |
+
download([f'{url}images/v1/patch{i}.tar.gz' for i in range(15 + 1)], dir=images, curl=True, threads=8)
|
419 |
+
download([f'{url}images/v2/patch{i}.tar.gz' for i in range(16, patches)], dir=images, curl=True, threads=8)
|
420 |
+
|
421 |
+
# Move
|
422 |
+
for f in tqdm(images.rglob('*.jpg'), desc=f'Moving {split} images'):
|
423 |
+
f.rename(images / f.name) # move to /images/{split}
|
424 |
+
|
425 |
+
# Labels
|
426 |
+
coco = COCO(dir / f'zhiyuan_objv2_{split}.json')
|
427 |
+
names = [x["name"] for x in coco.loadCats(coco.getCatIds())]
|
428 |
+
for cid, cat in enumerate(names):
|
429 |
+
catIds = coco.getCatIds(catNms=[cat])
|
430 |
+
imgIds = coco.getImgIds(catIds=catIds)
|
431 |
+
for im in tqdm(coco.loadImgs(imgIds), desc=f'Class {cid + 1}/{len(names)} {cat}'):
|
432 |
+
width, height = im["width"], im["height"]
|
433 |
+
path = Path(im["file_name"]) # image filename
|
434 |
+
try:
|
435 |
+
with open(labels / path.with_suffix('.txt').name, 'a') as file:
|
436 |
+
annIds = coco.getAnnIds(imgIds=im["id"], catIds=catIds, iscrowd=None)
|
437 |
+
for a in coco.loadAnns(annIds):
|
438 |
+
x, y, w, h = a['bbox'] # bounding box in xywh (xy top-left corner)
|
439 |
+
xyxy = np.array([x, y, x + w, y + h])[None] # pixels(1,4)
|
440 |
+
x, y, w, h = xyxy2xywhn(xyxy, w=width, h=height, clip=True)[0] # normalized and clipped
|
441 |
+
file.write(f"{cid} {x:.5f} {y:.5f} {w:.5f} {h:.5f}\n")
|
442 |
+
except Exception as e:
|
443 |
+
print(e)
|
ultralytics/datasets/SKU-110K.yaml
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# SKU-110K retail items dataset https://github.com/eg4000/SKU110K_CVPR19 by Trax Retail
|
3 |
+
# Example usage: yolo train data=SKU-110K.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── SKU-110K ← downloads here (13.6 GB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/SKU-110K # dataset root dir
|
12 |
+
train: train.txt # train images (relative to 'path') 8219 images
|
13 |
+
val: val.txt # val images (relative to 'path') 588 images
|
14 |
+
test: test.txt # test images (optional) 2936 images
|
15 |
+
|
16 |
+
# Classes
|
17 |
+
names:
|
18 |
+
0: object
|
19 |
+
|
20 |
+
|
21 |
+
# Download script/URL (optional) ---------------------------------------------------------------------------------------
|
22 |
+
download: |
|
23 |
+
import shutil
|
24 |
+
from pathlib import Path
|
25 |
+
|
26 |
+
import numpy as np
|
27 |
+
import pandas as pd
|
28 |
+
from tqdm import tqdm
|
29 |
+
|
30 |
+
from ultralytics.yolo.utils.downloads import download
|
31 |
+
from ultralytics.yolo.utils.ops import xyxy2xywh
|
32 |
+
|
33 |
+
# Download
|
34 |
+
dir = Path(yaml['path']) # dataset root dir
|
35 |
+
parent = Path(dir.parent) # download dir
|
36 |
+
urls = ['http://trax-geometry.s3.amazonaws.com/cvpr_challenge/SKU110K_fixed.tar.gz']
|
37 |
+
download(urls, dir=parent)
|
38 |
+
|
39 |
+
# Rename directories
|
40 |
+
if dir.exists():
|
41 |
+
shutil.rmtree(dir)
|
42 |
+
(parent / 'SKU110K_fixed').rename(dir) # rename dir
|
43 |
+
(dir / 'labels').mkdir(parents=True, exist_ok=True) # create labels dir
|
44 |
+
|
45 |
+
# Convert labels
|
46 |
+
names = 'image', 'x1', 'y1', 'x2', 'y2', 'class', 'image_width', 'image_height' # column names
|
47 |
+
for d in 'annotations_train.csv', 'annotations_val.csv', 'annotations_test.csv':
|
48 |
+
x = pd.read_csv(dir / 'annotations' / d, names=names).values # annotations
|
49 |
+
images, unique_images = x[:, 0], np.unique(x[:, 0])
|
50 |
+
with open((dir / d).with_suffix('.txt').__str__().replace('annotations_', ''), 'w') as f:
|
51 |
+
f.writelines(f'./images/{s}\n' for s in unique_images)
|
52 |
+
for im in tqdm(unique_images, desc=f'Converting {dir / d}'):
|
53 |
+
cls = 0 # single-class dataset
|
54 |
+
with open((dir / 'labels' / im).with_suffix('.txt'), 'a') as f:
|
55 |
+
for r in x[images == im]:
|
56 |
+
w, h = r[6], r[7] # image width, height
|
57 |
+
xywh = xyxy2xywh(np.array([[r[1] / w, r[2] / h, r[3] / w, r[4] / h]]))[0] # instance
|
58 |
+
f.write(f"{cls} {xywh[0]:.5f} {xywh[1]:.5f} {xywh[2]:.5f} {xywh[3]:.5f}\n") # write label
|
ultralytics/datasets/VOC.yaml
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# PASCAL VOC dataset http://host.robots.ox.ac.uk/pascal/VOC by University of Oxford
|
3 |
+
# Example usage: yolo train data=VOC.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── VOC ← downloads here (2.8 GB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/VOC
|
12 |
+
train: # train images (relative to 'path') 16551 images
|
13 |
+
- images/train2012
|
14 |
+
- images/train2007
|
15 |
+
- images/val2012
|
16 |
+
- images/val2007
|
17 |
+
val: # val images (relative to 'path') 4952 images
|
18 |
+
- images/test2007
|
19 |
+
test: # test images (optional)
|
20 |
+
- images/test2007
|
21 |
+
|
22 |
+
# Classes
|
23 |
+
names:
|
24 |
+
0: aeroplane
|
25 |
+
1: bicycle
|
26 |
+
2: bird
|
27 |
+
3: boat
|
28 |
+
4: bottle
|
29 |
+
5: bus
|
30 |
+
6: car
|
31 |
+
7: cat
|
32 |
+
8: chair
|
33 |
+
9: cow
|
34 |
+
10: diningtable
|
35 |
+
11: dog
|
36 |
+
12: horse
|
37 |
+
13: motorbike
|
38 |
+
14: person
|
39 |
+
15: pottedplant
|
40 |
+
16: sheep
|
41 |
+
17: sofa
|
42 |
+
18: train
|
43 |
+
19: tvmonitor
|
44 |
+
|
45 |
+
|
46 |
+
# Download script/URL (optional) ---------------------------------------------------------------------------------------
|
47 |
+
download: |
|
48 |
+
import xml.etree.ElementTree as ET
|
49 |
+
|
50 |
+
from tqdm import tqdm
|
51 |
+
from ultralytics.yolo.utils.downloads import download
|
52 |
+
from pathlib import Path
|
53 |
+
|
54 |
+
def convert_label(path, lb_path, year, image_id):
|
55 |
+
def convert_box(size, box):
|
56 |
+
dw, dh = 1. / size[0], 1. / size[1]
|
57 |
+
x, y, w, h = (box[0] + box[1]) / 2.0 - 1, (box[2] + box[3]) / 2.0 - 1, box[1] - box[0], box[3] - box[2]
|
58 |
+
return x * dw, y * dh, w * dw, h * dh
|
59 |
+
|
60 |
+
in_file = open(path / f'VOC{year}/Annotations/{image_id}.xml')
|
61 |
+
out_file = open(lb_path, 'w')
|
62 |
+
tree = ET.parse(in_file)
|
63 |
+
root = tree.getroot()
|
64 |
+
size = root.find('size')
|
65 |
+
w = int(size.find('width').text)
|
66 |
+
h = int(size.find('height').text)
|
67 |
+
|
68 |
+
names = list(yaml['names'].values()) # names list
|
69 |
+
for obj in root.iter('object'):
|
70 |
+
cls = obj.find('name').text
|
71 |
+
if cls in names and int(obj.find('difficult').text) != 1:
|
72 |
+
xmlbox = obj.find('bndbox')
|
73 |
+
bb = convert_box((w, h), [float(xmlbox.find(x).text) for x in ('xmin', 'xmax', 'ymin', 'ymax')])
|
74 |
+
cls_id = names.index(cls) # class id
|
75 |
+
out_file.write(" ".join([str(a) for a in (cls_id, *bb)]) + '\n')
|
76 |
+
|
77 |
+
|
78 |
+
# Download
|
79 |
+
dir = Path(yaml['path']) # dataset root dir
|
80 |
+
url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/'
|
81 |
+
urls = [f'{url}VOCtrainval_06-Nov-2007.zip', # 446MB, 5012 images
|
82 |
+
f'{url}VOCtest_06-Nov-2007.zip', # 438MB, 4953 images
|
83 |
+
f'{url}VOCtrainval_11-May-2012.zip'] # 1.95GB, 17126 images
|
84 |
+
download(urls, dir=dir / 'images', curl=True, threads=3)
|
85 |
+
|
86 |
+
# Convert
|
87 |
+
path = dir / 'images/VOCdevkit'
|
88 |
+
for year, image_set in ('2012', 'train'), ('2012', 'val'), ('2007', 'train'), ('2007', 'val'), ('2007', 'test'):
|
89 |
+
imgs_path = dir / 'images' / f'{image_set}{year}'
|
90 |
+
lbs_path = dir / 'labels' / f'{image_set}{year}'
|
91 |
+
imgs_path.mkdir(exist_ok=True, parents=True)
|
92 |
+
lbs_path.mkdir(exist_ok=True, parents=True)
|
93 |
+
|
94 |
+
with open(path / f'VOC{year}/ImageSets/Main/{image_set}.txt') as f:
|
95 |
+
image_ids = f.read().strip().split()
|
96 |
+
for id in tqdm(image_ids, desc=f'{image_set}{year}'):
|
97 |
+
f = path / f'VOC{year}/JPEGImages/{id}.jpg' # old img path
|
98 |
+
lb_path = (lbs_path / f.name).with_suffix('.txt') # new label path
|
99 |
+
f.rename(imgs_path / f.name) # move image
|
100 |
+
convert_label(path, lb_path, year, id) # convert labels to YOLO format
|
ultralytics/datasets/VisDrone.yaml
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# VisDrone2019-DET dataset https://github.com/VisDrone/VisDrone-Dataset by Tianjin University
|
3 |
+
# Example usage: yolo train data=VisDrone.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── VisDrone ← downloads here (2.3 GB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/VisDrone # dataset root dir
|
12 |
+
train: VisDrone2019-DET-train/images # train images (relative to 'path') 6471 images
|
13 |
+
val: VisDrone2019-DET-val/images # val images (relative to 'path') 548 images
|
14 |
+
test: VisDrone2019-DET-test-dev/images # test images (optional) 1610 images
|
15 |
+
|
16 |
+
# Classes
|
17 |
+
names:
|
18 |
+
0: pedestrian
|
19 |
+
1: people
|
20 |
+
2: bicycle
|
21 |
+
3: car
|
22 |
+
4: van
|
23 |
+
5: truck
|
24 |
+
6: tricycle
|
25 |
+
7: awning-tricycle
|
26 |
+
8: bus
|
27 |
+
9: motor
|
28 |
+
|
29 |
+
|
30 |
+
# Download script/URL (optional) ---------------------------------------------------------------------------------------
|
31 |
+
download: |
|
32 |
+
import os
|
33 |
+
from pathlib import Path
|
34 |
+
|
35 |
+
from ultralytics.yolo.utils.downloads import download
|
36 |
+
|
37 |
+
def visdrone2yolo(dir):
|
38 |
+
from PIL import Image
|
39 |
+
from tqdm import tqdm
|
40 |
+
|
41 |
+
def convert_box(size, box):
|
42 |
+
# Convert VisDrone box to YOLO xywh box
|
43 |
+
dw = 1. / size[0]
|
44 |
+
dh = 1. / size[1]
|
45 |
+
return (box[0] + box[2] / 2) * dw, (box[1] + box[3] / 2) * dh, box[2] * dw, box[3] * dh
|
46 |
+
|
47 |
+
(dir / 'labels').mkdir(parents=True, exist_ok=True) # make labels directory
|
48 |
+
pbar = tqdm((dir / 'annotations').glob('*.txt'), desc=f'Converting {dir}')
|
49 |
+
for f in pbar:
|
50 |
+
img_size = Image.open((dir / 'images' / f.name).with_suffix('.jpg')).size
|
51 |
+
lines = []
|
52 |
+
with open(f, 'r') as file: # read annotation.txt
|
53 |
+
for row in [x.split(',') for x in file.read().strip().splitlines()]:
|
54 |
+
if row[4] == '0': # VisDrone 'ignored regions' class 0
|
55 |
+
continue
|
56 |
+
cls = int(row[5]) - 1
|
57 |
+
box = convert_box(img_size, tuple(map(int, row[:4])))
|
58 |
+
lines.append(f"{cls} {' '.join(f'{x:.6f}' for x in box)}\n")
|
59 |
+
with open(str(f).replace(f'{os.sep}annotations{os.sep}', f'{os.sep}labels{os.sep}'), 'w') as fl:
|
60 |
+
fl.writelines(lines) # write label.txt
|
61 |
+
|
62 |
+
|
63 |
+
# Download
|
64 |
+
dir = Path(yaml['path']) # dataset root dir
|
65 |
+
urls = ['https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-train.zip',
|
66 |
+
'https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-val.zip',
|
67 |
+
'https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-test-dev.zip',
|
68 |
+
'https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-test-challenge.zip']
|
69 |
+
download(urls, dir=dir, curl=True, threads=4)
|
70 |
+
|
71 |
+
# Convert
|
72 |
+
for d in 'VisDrone2019-DET-train', 'VisDrone2019-DET-val', 'VisDrone2019-DET-test-dev':
|
73 |
+
visdrone2yolo(dir / d) # convert VisDrone annotations to YOLO labels
|
ultralytics/datasets/coco-pose.yaml
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# COCO 2017 dataset http://cocodataset.org by Microsoft
|
3 |
+
# Example usage: yolo train data=coco-pose.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── coco-pose ← downloads here (20.1 GB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/coco-pose # dataset root dir
|
12 |
+
train: train2017.txt # train images (relative to 'path') 118287 images
|
13 |
+
val: val2017.txt # val images (relative to 'path') 5000 images
|
14 |
+
test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794
|
15 |
+
|
16 |
+
# Keypoints
|
17 |
+
kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
|
18 |
+
flip_idx: [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
|
19 |
+
|
20 |
+
# Classes
|
21 |
+
names:
|
22 |
+
0: person
|
23 |
+
|
24 |
+
# Download script/URL (optional)
|
25 |
+
download: |
|
26 |
+
from ultralytics.yolo.utils.downloads import download
|
27 |
+
from pathlib import Path
|
28 |
+
|
29 |
+
# Download labels
|
30 |
+
dir = Path(yaml['path']) # dataset root dir
|
31 |
+
url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/'
|
32 |
+
urls = [url + 'coco2017labels-pose.zip'] # labels
|
33 |
+
download(urls, dir=dir.parent)
|
34 |
+
# Download data
|
35 |
+
urls = ['http://images.cocodataset.org/zips/train2017.zip', # 19G, 118k images
|
36 |
+
'http://images.cocodataset.org/zips/val2017.zip', # 1G, 5k images
|
37 |
+
'http://images.cocodataset.org/zips/test2017.zip'] # 7G, 41k images (optional)
|
38 |
+
download(urls, dir=dir / 'images', threads=3)
|
ultralytics/datasets/coco.yaml
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# COCO 2017 dataset http://cocodataset.org by Microsoft
|
3 |
+
# Example usage: yolo train data=coco.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── coco ← downloads here (20.1 GB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/coco # dataset root dir
|
12 |
+
train: train2017.txt # train images (relative to 'path') 118287 images
|
13 |
+
val: val2017.txt # val images (relative to 'path') 5000 images
|
14 |
+
test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794
|
15 |
+
|
16 |
+
# Classes
|
17 |
+
names:
|
18 |
+
0: person
|
19 |
+
1: bicycle
|
20 |
+
2: car
|
21 |
+
3: motorcycle
|
22 |
+
4: airplane
|
23 |
+
5: bus
|
24 |
+
6: train
|
25 |
+
7: truck
|
26 |
+
8: boat
|
27 |
+
9: traffic light
|
28 |
+
10: fire hydrant
|
29 |
+
11: stop sign
|
30 |
+
12: parking meter
|
31 |
+
13: bench
|
32 |
+
14: bird
|
33 |
+
15: cat
|
34 |
+
16: dog
|
35 |
+
17: horse
|
36 |
+
18: sheep
|
37 |
+
19: cow
|
38 |
+
20: elephant
|
39 |
+
21: bear
|
40 |
+
22: zebra
|
41 |
+
23: giraffe
|
42 |
+
24: backpack
|
43 |
+
25: umbrella
|
44 |
+
26: handbag
|
45 |
+
27: tie
|
46 |
+
28: suitcase
|
47 |
+
29: frisbee
|
48 |
+
30: skis
|
49 |
+
31: snowboard
|
50 |
+
32: sports ball
|
51 |
+
33: kite
|
52 |
+
34: baseball bat
|
53 |
+
35: baseball glove
|
54 |
+
36: skateboard
|
55 |
+
37: surfboard
|
56 |
+
38: tennis racket
|
57 |
+
39: bottle
|
58 |
+
40: wine glass
|
59 |
+
41: cup
|
60 |
+
42: fork
|
61 |
+
43: knife
|
62 |
+
44: spoon
|
63 |
+
45: bowl
|
64 |
+
46: banana
|
65 |
+
47: apple
|
66 |
+
48: sandwich
|
67 |
+
49: orange
|
68 |
+
50: broccoli
|
69 |
+
51: carrot
|
70 |
+
52: hot dog
|
71 |
+
53: pizza
|
72 |
+
54: donut
|
73 |
+
55: cake
|
74 |
+
56: chair
|
75 |
+
57: couch
|
76 |
+
58: potted plant
|
77 |
+
59: bed
|
78 |
+
60: dining table
|
79 |
+
61: toilet
|
80 |
+
62: tv
|
81 |
+
63: laptop
|
82 |
+
64: mouse
|
83 |
+
65: remote
|
84 |
+
66: keyboard
|
85 |
+
67: cell phone
|
86 |
+
68: microwave
|
87 |
+
69: oven
|
88 |
+
70: toaster
|
89 |
+
71: sink
|
90 |
+
72: refrigerator
|
91 |
+
73: book
|
92 |
+
74: clock
|
93 |
+
75: vase
|
94 |
+
76: scissors
|
95 |
+
77: teddy bear
|
96 |
+
78: hair drier
|
97 |
+
79: toothbrush
|
98 |
+
|
99 |
+
|
100 |
+
# Download script/URL (optional)
|
101 |
+
download: |
|
102 |
+
from ultralytics.yolo.utils.downloads import download
|
103 |
+
from pathlib import Path
|
104 |
+
|
105 |
+
# Download labels
|
106 |
+
segments = True # segment or box labels
|
107 |
+
dir = Path(yaml['path']) # dataset root dir
|
108 |
+
url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/'
|
109 |
+
urls = [url + ('coco2017labels-segments.zip' if segments else 'coco2017labels.zip')] # labels
|
110 |
+
download(urls, dir=dir.parent)
|
111 |
+
# Download data
|
112 |
+
urls = ['http://images.cocodataset.org/zips/train2017.zip', # 19G, 118k images
|
113 |
+
'http://images.cocodataset.org/zips/val2017.zip', # 1G, 5k images
|
114 |
+
'http://images.cocodataset.org/zips/test2017.zip'] # 7G, 41k images (optional)
|
115 |
+
download(urls, dir=dir / 'images', threads=3)
|
ultralytics/datasets/coco128-seg.yaml
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# COCO128-seg dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017) by Ultralytics
|
3 |
+
# Example usage: yolo train data=coco128.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── coco128-seg ← downloads here (7 MB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/coco128-seg # dataset root dir
|
12 |
+
train: images/train2017 # train images (relative to 'path') 128 images
|
13 |
+
val: images/train2017 # val images (relative to 'path') 128 images
|
14 |
+
test: # test images (optional)
|
15 |
+
|
16 |
+
# Classes
|
17 |
+
names:
|
18 |
+
0: person
|
19 |
+
1: bicycle
|
20 |
+
2: car
|
21 |
+
3: motorcycle
|
22 |
+
4: airplane
|
23 |
+
5: bus
|
24 |
+
6: train
|
25 |
+
7: truck
|
26 |
+
8: boat
|
27 |
+
9: traffic light
|
28 |
+
10: fire hydrant
|
29 |
+
11: stop sign
|
30 |
+
12: parking meter
|
31 |
+
13: bench
|
32 |
+
14: bird
|
33 |
+
15: cat
|
34 |
+
16: dog
|
35 |
+
17: horse
|
36 |
+
18: sheep
|
37 |
+
19: cow
|
38 |
+
20: elephant
|
39 |
+
21: bear
|
40 |
+
22: zebra
|
41 |
+
23: giraffe
|
42 |
+
24: backpack
|
43 |
+
25: umbrella
|
44 |
+
26: handbag
|
45 |
+
27: tie
|
46 |
+
28: suitcase
|
47 |
+
29: frisbee
|
48 |
+
30: skis
|
49 |
+
31: snowboard
|
50 |
+
32: sports ball
|
51 |
+
33: kite
|
52 |
+
34: baseball bat
|
53 |
+
35: baseball glove
|
54 |
+
36: skateboard
|
55 |
+
37: surfboard
|
56 |
+
38: tennis racket
|
57 |
+
39: bottle
|
58 |
+
40: wine glass
|
59 |
+
41: cup
|
60 |
+
42: fork
|
61 |
+
43: knife
|
62 |
+
44: spoon
|
63 |
+
45: bowl
|
64 |
+
46: banana
|
65 |
+
47: apple
|
66 |
+
48: sandwich
|
67 |
+
49: orange
|
68 |
+
50: broccoli
|
69 |
+
51: carrot
|
70 |
+
52: hot dog
|
71 |
+
53: pizza
|
72 |
+
54: donut
|
73 |
+
55: cake
|
74 |
+
56: chair
|
75 |
+
57: couch
|
76 |
+
58: potted plant
|
77 |
+
59: bed
|
78 |
+
60: dining table
|
79 |
+
61: toilet
|
80 |
+
62: tv
|
81 |
+
63: laptop
|
82 |
+
64: mouse
|
83 |
+
65: remote
|
84 |
+
66: keyboard
|
85 |
+
67: cell phone
|
86 |
+
68: microwave
|
87 |
+
69: oven
|
88 |
+
70: toaster
|
89 |
+
71: sink
|
90 |
+
72: refrigerator
|
91 |
+
73: book
|
92 |
+
74: clock
|
93 |
+
75: vase
|
94 |
+
76: scissors
|
95 |
+
77: teddy bear
|
96 |
+
78: hair drier
|
97 |
+
79: toothbrush
|
98 |
+
|
99 |
+
|
100 |
+
# Download script/URL (optional)
|
101 |
+
download: https://ultralytics.com/assets/coco128-seg.zip
|
ultralytics/datasets/coco128.yaml
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# COCO128 dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017) by Ultralytics
|
3 |
+
# Example usage: yolo train data=coco128.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── coco128 ← downloads here (7 MB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/coco128 # dataset root dir
|
12 |
+
train: images/train2017 # train images (relative to 'path') 128 images
|
13 |
+
val: images/train2017 # val images (relative to 'path') 128 images
|
14 |
+
test: # test images (optional)
|
15 |
+
|
16 |
+
# Classes
|
17 |
+
names:
|
18 |
+
0: person
|
19 |
+
1: bicycle
|
20 |
+
2: car
|
21 |
+
3: motorcycle
|
22 |
+
4: airplane
|
23 |
+
5: bus
|
24 |
+
6: train
|
25 |
+
7: truck
|
26 |
+
8: boat
|
27 |
+
9: traffic light
|
28 |
+
10: fire hydrant
|
29 |
+
11: stop sign
|
30 |
+
12: parking meter
|
31 |
+
13: bench
|
32 |
+
14: bird
|
33 |
+
15: cat
|
34 |
+
16: dog
|
35 |
+
17: horse
|
36 |
+
18: sheep
|
37 |
+
19: cow
|
38 |
+
20: elephant
|
39 |
+
21: bear
|
40 |
+
22: zebra
|
41 |
+
23: giraffe
|
42 |
+
24: backpack
|
43 |
+
25: umbrella
|
44 |
+
26: handbag
|
45 |
+
27: tie
|
46 |
+
28: suitcase
|
47 |
+
29: frisbee
|
48 |
+
30: skis
|
49 |
+
31: snowboard
|
50 |
+
32: sports ball
|
51 |
+
33: kite
|
52 |
+
34: baseball bat
|
53 |
+
35: baseball glove
|
54 |
+
36: skateboard
|
55 |
+
37: surfboard
|
56 |
+
38: tennis racket
|
57 |
+
39: bottle
|
58 |
+
40: wine glass
|
59 |
+
41: cup
|
60 |
+
42: fork
|
61 |
+
43: knife
|
62 |
+
44: spoon
|
63 |
+
45: bowl
|
64 |
+
46: banana
|
65 |
+
47: apple
|
66 |
+
48: sandwich
|
67 |
+
49: orange
|
68 |
+
50: broccoli
|
69 |
+
51: carrot
|
70 |
+
52: hot dog
|
71 |
+
53: pizza
|
72 |
+
54: donut
|
73 |
+
55: cake
|
74 |
+
56: chair
|
75 |
+
57: couch
|
76 |
+
58: potted plant
|
77 |
+
59: bed
|
78 |
+
60: dining table
|
79 |
+
61: toilet
|
80 |
+
62: tv
|
81 |
+
63: laptop
|
82 |
+
64: mouse
|
83 |
+
65: remote
|
84 |
+
66: keyboard
|
85 |
+
67: cell phone
|
86 |
+
68: microwave
|
87 |
+
69: oven
|
88 |
+
70: toaster
|
89 |
+
71: sink
|
90 |
+
72: refrigerator
|
91 |
+
73: book
|
92 |
+
74: clock
|
93 |
+
75: vase
|
94 |
+
76: scissors
|
95 |
+
77: teddy bear
|
96 |
+
78: hair drier
|
97 |
+
79: toothbrush
|
98 |
+
|
99 |
+
|
100 |
+
# Download script/URL (optional)
|
101 |
+
download: https://ultralytics.com/assets/coco128.zip
|
ultralytics/datasets/coco8-pose.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# COCO8-pose dataset (first 8 images from COCO train2017) by Ultralytics
|
3 |
+
# Example usage: yolo train data=coco8-pose.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── coco8-pose ← downloads here (1 MB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/coco8-pose # dataset root dir
|
12 |
+
train: images/train # train images (relative to 'path') 4 images
|
13 |
+
val: images/val # val images (relative to 'path') 4 images
|
14 |
+
test: # test images (optional)
|
15 |
+
|
16 |
+
# Keypoints
|
17 |
+
kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
|
18 |
+
flip_idx: [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
|
19 |
+
|
20 |
+
# Classes
|
21 |
+
names:
|
22 |
+
0: person
|
23 |
+
|
24 |
+
# Download script/URL (optional)
|
25 |
+
download: https://ultralytics.com/assets/coco8-pose.zip
|
ultralytics/datasets/coco8-seg.yaml
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# COCO8-seg dataset (first 8 images from COCO train2017) by Ultralytics
|
3 |
+
# Example usage: yolo train data=coco8-seg.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── coco8-seg ← downloads here (1 MB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/coco8-seg # dataset root dir
|
12 |
+
train: images/train # train images (relative to 'path') 4 images
|
13 |
+
val: images/val # val images (relative to 'path') 4 images
|
14 |
+
test: # test images (optional)
|
15 |
+
|
16 |
+
# Classes
|
17 |
+
names:
|
18 |
+
0: person
|
19 |
+
1: bicycle
|
20 |
+
2: car
|
21 |
+
3: motorcycle
|
22 |
+
4: airplane
|
23 |
+
5: bus
|
24 |
+
6: train
|
25 |
+
7: truck
|
26 |
+
8: boat
|
27 |
+
9: traffic light
|
28 |
+
10: fire hydrant
|
29 |
+
11: stop sign
|
30 |
+
12: parking meter
|
31 |
+
13: bench
|
32 |
+
14: bird
|
33 |
+
15: cat
|
34 |
+
16: dog
|
35 |
+
17: horse
|
36 |
+
18: sheep
|
37 |
+
19: cow
|
38 |
+
20: elephant
|
39 |
+
21: bear
|
40 |
+
22: zebra
|
41 |
+
23: giraffe
|
42 |
+
24: backpack
|
43 |
+
25: umbrella
|
44 |
+
26: handbag
|
45 |
+
27: tie
|
46 |
+
28: suitcase
|
47 |
+
29: frisbee
|
48 |
+
30: skis
|
49 |
+
31: snowboard
|
50 |
+
32: sports ball
|
51 |
+
33: kite
|
52 |
+
34: baseball bat
|
53 |
+
35: baseball glove
|
54 |
+
36: skateboard
|
55 |
+
37: surfboard
|
56 |
+
38: tennis racket
|
57 |
+
39: bottle
|
58 |
+
40: wine glass
|
59 |
+
41: cup
|
60 |
+
42: fork
|
61 |
+
43: knife
|
62 |
+
44: spoon
|
63 |
+
45: bowl
|
64 |
+
46: banana
|
65 |
+
47: apple
|
66 |
+
48: sandwich
|
67 |
+
49: orange
|
68 |
+
50: broccoli
|
69 |
+
51: carrot
|
70 |
+
52: hot dog
|
71 |
+
53: pizza
|
72 |
+
54: donut
|
73 |
+
55: cake
|
74 |
+
56: chair
|
75 |
+
57: couch
|
76 |
+
58: potted plant
|
77 |
+
59: bed
|
78 |
+
60: dining table
|
79 |
+
61: toilet
|
80 |
+
62: tv
|
81 |
+
63: laptop
|
82 |
+
64: mouse
|
83 |
+
65: remote
|
84 |
+
66: keyboard
|
85 |
+
67: cell phone
|
86 |
+
68: microwave
|
87 |
+
69: oven
|
88 |
+
70: toaster
|
89 |
+
71: sink
|
90 |
+
72: refrigerator
|
91 |
+
73: book
|
92 |
+
74: clock
|
93 |
+
75: vase
|
94 |
+
76: scissors
|
95 |
+
77: teddy bear
|
96 |
+
78: hair drier
|
97 |
+
79: toothbrush
|
98 |
+
|
99 |
+
|
100 |
+
# Download script/URL (optional)
|
101 |
+
download: https://ultralytics.com/assets/coco8-seg.zip
|
ultralytics/datasets/coco8.yaml
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# COCO8 dataset (first 8 images from COCO train2017) by Ultralytics
|
3 |
+
# Example usage: yolo train data=coco8.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── coco8 ← downloads here (1 MB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/coco8 # dataset root dir
|
12 |
+
train: images/train # train images (relative to 'path') 4 images
|
13 |
+
val: images/val # val images (relative to 'path') 4 images
|
14 |
+
test: # test images (optional)
|
15 |
+
|
16 |
+
# Classes
|
17 |
+
names:
|
18 |
+
0: person
|
19 |
+
1: bicycle
|
20 |
+
2: car
|
21 |
+
3: motorcycle
|
22 |
+
4: airplane
|
23 |
+
5: bus
|
24 |
+
6: train
|
25 |
+
7: truck
|
26 |
+
8: boat
|
27 |
+
9: traffic light
|
28 |
+
10: fire hydrant
|
29 |
+
11: stop sign
|
30 |
+
12: parking meter
|
31 |
+
13: bench
|
32 |
+
14: bird
|
33 |
+
15: cat
|
34 |
+
16: dog
|
35 |
+
17: horse
|
36 |
+
18: sheep
|
37 |
+
19: cow
|
38 |
+
20: elephant
|
39 |
+
21: bear
|
40 |
+
22: zebra
|
41 |
+
23: giraffe
|
42 |
+
24: backpack
|
43 |
+
25: umbrella
|
44 |
+
26: handbag
|
45 |
+
27: tie
|
46 |
+
28: suitcase
|
47 |
+
29: frisbee
|
48 |
+
30: skis
|
49 |
+
31: snowboard
|
50 |
+
32: sports ball
|
51 |
+
33: kite
|
52 |
+
34: baseball bat
|
53 |
+
35: baseball glove
|
54 |
+
36: skateboard
|
55 |
+
37: surfboard
|
56 |
+
38: tennis racket
|
57 |
+
39: bottle
|
58 |
+
40: wine glass
|
59 |
+
41: cup
|
60 |
+
42: fork
|
61 |
+
43: knife
|
62 |
+
44: spoon
|
63 |
+
45: bowl
|
64 |
+
46: banana
|
65 |
+
47: apple
|
66 |
+
48: sandwich
|
67 |
+
49: orange
|
68 |
+
50: broccoli
|
69 |
+
51: carrot
|
70 |
+
52: hot dog
|
71 |
+
53: pizza
|
72 |
+
54: donut
|
73 |
+
55: cake
|
74 |
+
56: chair
|
75 |
+
57: couch
|
76 |
+
58: potted plant
|
77 |
+
59: bed
|
78 |
+
60: dining table
|
79 |
+
61: toilet
|
80 |
+
62: tv
|
81 |
+
63: laptop
|
82 |
+
64: mouse
|
83 |
+
65: remote
|
84 |
+
66: keyboard
|
85 |
+
67: cell phone
|
86 |
+
68: microwave
|
87 |
+
69: oven
|
88 |
+
70: toaster
|
89 |
+
71: sink
|
90 |
+
72: refrigerator
|
91 |
+
73: book
|
92 |
+
74: clock
|
93 |
+
75: vase
|
94 |
+
76: scissors
|
95 |
+
77: teddy bear
|
96 |
+
78: hair drier
|
97 |
+
79: toothbrush
|
98 |
+
|
99 |
+
|
100 |
+
# Download script/URL (optional)
|
101 |
+
download: https://ultralytics.com/assets/coco8.zip
|
ultralytics/datasets/xView.yaml
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# DIUx xView 2018 Challenge https://challenge.xviewdataset.org by U.S. National Geospatial-Intelligence Agency (NGA)
|
3 |
+
# -------- DOWNLOAD DATA MANUALLY and jar xf val_images.zip to 'datasets/xView' before running train command! --------
|
4 |
+
# Example usage: yolo train data=xView.yaml
|
5 |
+
# parent
|
6 |
+
# ├── ultralytics
|
7 |
+
# └── datasets
|
8 |
+
# └── xView ← downloads here (20.7 GB)
|
9 |
+
|
10 |
+
|
11 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
12 |
+
path: ../datasets/xView # dataset root dir
|
13 |
+
train: images/autosplit_train.txt # train images (relative to 'path') 90% of 847 train images
|
14 |
+
val: images/autosplit_val.txt # train images (relative to 'path') 10% of 847 train images
|
15 |
+
|
16 |
+
# Classes
|
17 |
+
names:
|
18 |
+
0: Fixed-wing Aircraft
|
19 |
+
1: Small Aircraft
|
20 |
+
2: Cargo Plane
|
21 |
+
3: Helicopter
|
22 |
+
4: Passenger Vehicle
|
23 |
+
5: Small Car
|
24 |
+
6: Bus
|
25 |
+
7: Pickup Truck
|
26 |
+
8: Utility Truck
|
27 |
+
9: Truck
|
28 |
+
10: Cargo Truck
|
29 |
+
11: Truck w/Box
|
30 |
+
12: Truck Tractor
|
31 |
+
13: Trailer
|
32 |
+
14: Truck w/Flatbed
|
33 |
+
15: Truck w/Liquid
|
34 |
+
16: Crane Truck
|
35 |
+
17: Railway Vehicle
|
36 |
+
18: Passenger Car
|
37 |
+
19: Cargo Car
|
38 |
+
20: Flat Car
|
39 |
+
21: Tank car
|
40 |
+
22: Locomotive
|
41 |
+
23: Maritime Vessel
|
42 |
+
24: Motorboat
|
43 |
+
25: Sailboat
|
44 |
+
26: Tugboat
|
45 |
+
27: Barge
|
46 |
+
28: Fishing Vessel
|
47 |
+
29: Ferry
|
48 |
+
30: Yacht
|
49 |
+
31: Container Ship
|
50 |
+
32: Oil Tanker
|
51 |
+
33: Engineering Vehicle
|
52 |
+
34: Tower crane
|
53 |
+
35: Container Crane
|
54 |
+
36: Reach Stacker
|
55 |
+
37: Straddle Carrier
|
56 |
+
38: Mobile Crane
|
57 |
+
39: Dump Truck
|
58 |
+
40: Haul Truck
|
59 |
+
41: Scraper/Tractor
|
60 |
+
42: Front loader/Bulldozer
|
61 |
+
43: Excavator
|
62 |
+
44: Cement Mixer
|
63 |
+
45: Ground Grader
|
64 |
+
46: Hut/Tent
|
65 |
+
47: Shed
|
66 |
+
48: Building
|
67 |
+
49: Aircraft Hangar
|
68 |
+
50: Damaged Building
|
69 |
+
51: Facility
|
70 |
+
52: Construction Site
|
71 |
+
53: Vehicle Lot
|
72 |
+
54: Helipad
|
73 |
+
55: Storage Tank
|
74 |
+
56: Shipping container lot
|
75 |
+
57: Shipping Container
|
76 |
+
58: Pylon
|
77 |
+
59: Tower
|
78 |
+
|
79 |
+
|
80 |
+
# Download script/URL (optional) ---------------------------------------------------------------------------------------
|
81 |
+
download: |
|
82 |
+
import json
|
83 |
+
import os
|
84 |
+
from pathlib import Path
|
85 |
+
|
86 |
+
import numpy as np
|
87 |
+
from PIL import Image
|
88 |
+
from tqdm import tqdm
|
89 |
+
|
90 |
+
from ultralytics.yolo.data.dataloaders.v5loader import autosplit
|
91 |
+
from ultralytics.yolo.utils.ops import xyxy2xywhn
|
92 |
+
|
93 |
+
|
94 |
+
def convert_labels(fname=Path('xView/xView_train.geojson')):
|
95 |
+
# Convert xView geoJSON labels to YOLO format
|
96 |
+
path = fname.parent
|
97 |
+
with open(fname) as f:
|
98 |
+
print(f'Loading {fname}...')
|
99 |
+
data = json.load(f)
|
100 |
+
|
101 |
+
# Make dirs
|
102 |
+
labels = Path(path / 'labels' / 'train')
|
103 |
+
os.system(f'rm -rf {labels}')
|
104 |
+
labels.mkdir(parents=True, exist_ok=True)
|
105 |
+
|
106 |
+
# xView classes 11-94 to 0-59
|
107 |
+
xview_class2index = [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 1, 2, -1, 3, -1, 4, 5, 6, 7, 8, -1, 9, 10, 11,
|
108 |
+
12, 13, 14, 15, -1, -1, 16, 17, 18, 19, 20, 21, 22, -1, 23, 24, 25, -1, 26, 27, -1, 28, -1,
|
109 |
+
29, 30, 31, 32, 33, 34, 35, 36, 37, -1, 38, 39, 40, 41, 42, 43, 44, 45, -1, -1, -1, -1, 46,
|
110 |
+
47, 48, 49, -1, 50, 51, -1, 52, -1, -1, -1, 53, 54, -1, 55, -1, -1, 56, -1, 57, -1, 58, 59]
|
111 |
+
|
112 |
+
shapes = {}
|
113 |
+
for feature in tqdm(data['features'], desc=f'Converting {fname}'):
|
114 |
+
p = feature['properties']
|
115 |
+
if p['bounds_imcoords']:
|
116 |
+
id = p['image_id']
|
117 |
+
file = path / 'train_images' / id
|
118 |
+
if file.exists(): # 1395.tif missing
|
119 |
+
try:
|
120 |
+
box = np.array([int(num) for num in p['bounds_imcoords'].split(",")])
|
121 |
+
assert box.shape[0] == 4, f'incorrect box shape {box.shape[0]}'
|
122 |
+
cls = p['type_id']
|
123 |
+
cls = xview_class2index[int(cls)] # xView class to 0-60
|
124 |
+
assert 59 >= cls >= 0, f'incorrect class index {cls}'
|
125 |
+
|
126 |
+
# Write YOLO label
|
127 |
+
if id not in shapes:
|
128 |
+
shapes[id] = Image.open(file).size
|
129 |
+
box = xyxy2xywhn(box[None].astype(np.float), w=shapes[id][0], h=shapes[id][1], clip=True)
|
130 |
+
with open((labels / id).with_suffix('.txt'), 'a') as f:
|
131 |
+
f.write(f"{cls} {' '.join(f'{x:.6f}' for x in box[0])}\n") # write label.txt
|
132 |
+
except Exception as e:
|
133 |
+
print(f'WARNING: skipping one label for {file}: {e}')
|
134 |
+
|
135 |
+
|
136 |
+
# Download manually from https://challenge.xviewdataset.org
|
137 |
+
dir = Path(yaml['path']) # dataset root dir
|
138 |
+
# urls = ['https://d307kc0mrhucc3.cloudfront.net/train_labels.zip', # train labels
|
139 |
+
# 'https://d307kc0mrhucc3.cloudfront.net/train_images.zip', # 15G, 847 train images
|
140 |
+
# 'https://d307kc0mrhucc3.cloudfront.net/val_images.zip'] # 5G, 282 val images (no labels)
|
141 |
+
# download(urls, dir=dir)
|
142 |
+
|
143 |
+
# Convert labels
|
144 |
+
convert_labels(dir / 'xView_train.geojson')
|
145 |
+
|
146 |
+
# Move images
|
147 |
+
images = Path(dir / 'images')
|
148 |
+
images.mkdir(parents=True, exist_ok=True)
|
149 |
+
Path(dir / 'train_images').rename(dir / 'images' / 'train')
|
150 |
+
Path(dir / 'val_images').rename(dir / 'images' / 'val')
|
151 |
+
|
152 |
+
# Split
|
153 |
+
autosplit(dir / 'images' / 'train')
|
ultralytics/hub/__init__.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
import requests
|
4 |
+
|
5 |
+
from ultralytics.hub.auth import Auth
|
6 |
+
from ultralytics.hub.utils import PREFIX
|
7 |
+
from ultralytics.yolo.data.utils import HUBDatasetStats
|
8 |
+
from ultralytics.yolo.utils import LOGGER, SETTINGS, USER_CONFIG_DIR, yaml_save
|
9 |
+
|
10 |
+
|
11 |
+
def login(api_key=''):
|
12 |
+
"""
|
13 |
+
Log in to the Ultralytics HUB API using the provided API key.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
|
17 |
+
|
18 |
+
Example:
|
19 |
+
from ultralytics import hub
|
20 |
+
hub.login('API_KEY')
|
21 |
+
"""
|
22 |
+
Auth(api_key, verbose=True)
|
23 |
+
|
24 |
+
|
25 |
+
def logout():
|
26 |
+
"""
|
27 |
+
Log out of Ultralytics HUB by removing the API key from the settings file. To log in again, use 'yolo hub login'.
|
28 |
+
|
29 |
+
Example:
|
30 |
+
from ultralytics import hub
|
31 |
+
hub.logout()
|
32 |
+
"""
|
33 |
+
SETTINGS['api_key'] = ''
|
34 |
+
yaml_save(USER_CONFIG_DIR / 'settings.yaml', SETTINGS)
|
35 |
+
LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.")
|
36 |
+
|
37 |
+
|
38 |
+
def start(key=''):
|
39 |
+
"""
|
40 |
+
Start training models with Ultralytics HUB (DEPRECATED).
|
41 |
+
|
42 |
+
Args:
|
43 |
+
key (str, optional): A string containing either the API key and model ID combination (apikey_modelid),
|
44 |
+
or the full model URL (https://hub.ultralytics.com/models/apikey_modelid).
|
45 |
+
"""
|
46 |
+
api_key, model_id = key.split('_')
|
47 |
+
LOGGER.warning(f"""
|
48 |
+
WARNING ⚠️ ultralytics.start() is deprecated after 8.0.60. Updated usage to train Ultralytics HUB models is:
|
49 |
+
|
50 |
+
from ultralytics import YOLO, hub
|
51 |
+
|
52 |
+
hub.login('{api_key}')
|
53 |
+
model = YOLO('https://hub.ultralytics.com/models/{model_id}')
|
54 |
+
model.train()""")
|
55 |
+
|
56 |
+
|
57 |
+
def reset_model(model_id=''):
|
58 |
+
"""Reset a trained model to an untrained state."""
|
59 |
+
r = requests.post('https://api.ultralytics.com/model-reset', json={'apiKey': Auth().api_key, 'modelId': model_id})
|
60 |
+
if r.status_code == 200:
|
61 |
+
LOGGER.info(f'{PREFIX}Model reset successfully')
|
62 |
+
return
|
63 |
+
LOGGER.warning(f'{PREFIX}Model reset failure {r.status_code} {r.reason}')
|
64 |
+
|
65 |
+
|
66 |
+
def export_fmts_hub():
|
67 |
+
"""Returns a list of HUB-supported export formats."""
|
68 |
+
from ultralytics.yolo.engine.exporter import export_formats
|
69 |
+
return list(export_formats()['Argument'][1:]) + ['ultralytics_tflite', 'ultralytics_coreml']
|
70 |
+
|
71 |
+
|
72 |
+
def export_model(model_id='', format='torchscript'):
|
73 |
+
"""Export a model to all formats."""
|
74 |
+
assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
|
75 |
+
r = requests.post(f'https://api.ultralytics.com/v1/models/{model_id}/export',
|
76 |
+
json={'format': format},
|
77 |
+
headers={'x-api-key': Auth().api_key})
|
78 |
+
assert r.status_code == 200, f'{PREFIX}{format} export failure {r.status_code} {r.reason}'
|
79 |
+
LOGGER.info(f'{PREFIX}{format} export started ✅')
|
80 |
+
|
81 |
+
|
82 |
+
def get_export(model_id='', format='torchscript'):
|
83 |
+
"""Get an exported model dictionary with download URL."""
|
84 |
+
assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
|
85 |
+
r = requests.post('https://api.ultralytics.com/get-export',
|
86 |
+
json={
|
87 |
+
'apiKey': Auth().api_key,
|
88 |
+
'modelId': model_id,
|
89 |
+
'format': format})
|
90 |
+
assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}'
|
91 |
+
return r.json()
|
92 |
+
|
93 |
+
|
94 |
+
def check_dataset(path='', task='detect'):
|
95 |
+
"""
|
96 |
+
Function for error-checking HUB dataset Zip file before upload
|
97 |
+
|
98 |
+
Arguments
|
99 |
+
path: Path to data.zip (with data.yaml inside data.zip)
|
100 |
+
task: Dataset task. Options are 'detect', 'segment', 'pose', 'classify'.
|
101 |
+
|
102 |
+
Usage
|
103 |
+
from ultralytics.hub import check_dataset
|
104 |
+
check_dataset('path/to/coco8.zip', task='detect') # detect dataset
|
105 |
+
check_dataset('path/to/coco8-seg.zip', task='segment') # segment dataset
|
106 |
+
check_dataset('path/to/coco8-pose.zip', task='pose') # pose dataset
|
107 |
+
"""
|
108 |
+
HUBDatasetStats(path=path, task=task).get_json()
|
109 |
+
LOGGER.info('Checks completed correctly ✅. Upload this dataset to https://hub.ultralytics.com/datasets/.')
|
110 |
+
|
111 |
+
|
112 |
+
if __name__ == '__main__':
|
113 |
+
start()
|
ultralytics/hub/auth.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
import requests
|
4 |
+
|
5 |
+
from ultralytics.hub.utils import HUB_API_ROOT, PREFIX, request_with_credentials
|
6 |
+
from ultralytics.yolo.utils import LOGGER, SETTINGS, emojis, is_colab, set_settings
|
7 |
+
|
8 |
+
API_KEY_URL = 'https://hub.ultralytics.com/settings?tab=api+keys'
|
9 |
+
|
10 |
+
|
11 |
+
class Auth:
|
12 |
+
id_token = api_key = model_key = False
|
13 |
+
|
14 |
+
def __init__(self, api_key='', verbose=False):
|
15 |
+
"""
|
16 |
+
Initialize the Auth class with an optional API key.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
|
20 |
+
"""
|
21 |
+
# Split the input API key in case it contains a combined key_model and keep only the API key part
|
22 |
+
api_key = api_key.split('_')[0]
|
23 |
+
|
24 |
+
# Set API key attribute as value passed or SETTINGS API key if none passed
|
25 |
+
self.api_key = api_key or SETTINGS.get('api_key', '')
|
26 |
+
|
27 |
+
# If an API key is provided
|
28 |
+
if self.api_key:
|
29 |
+
# If the provided API key matches the API key in the SETTINGS
|
30 |
+
if self.api_key == SETTINGS.get('api_key'):
|
31 |
+
# Log that the user is already logged in
|
32 |
+
if verbose:
|
33 |
+
LOGGER.info(f'{PREFIX}Authenticated ✅')
|
34 |
+
return
|
35 |
+
else:
|
36 |
+
# Attempt to authenticate with the provided API key
|
37 |
+
success = self.authenticate()
|
38 |
+
# If the API key is not provided and the environment is a Google Colab notebook
|
39 |
+
elif is_colab():
|
40 |
+
# Attempt to authenticate using browser cookies
|
41 |
+
success = self.auth_with_cookies()
|
42 |
+
else:
|
43 |
+
# Request an API key
|
44 |
+
success = self.request_api_key()
|
45 |
+
|
46 |
+
# Update SETTINGS with the new API key after successful authentication
|
47 |
+
if success:
|
48 |
+
set_settings({'api_key': self.api_key})
|
49 |
+
# Log that the new login was successful
|
50 |
+
if verbose:
|
51 |
+
LOGGER.info(f'{PREFIX}New authentication successful ✅')
|
52 |
+
elif verbose:
|
53 |
+
LOGGER.info(f'{PREFIX}Retrieve API key from {API_KEY_URL}')
|
54 |
+
|
55 |
+
def request_api_key(self, max_attempts=3):
|
56 |
+
"""
|
57 |
+
Prompt the user to input their API key. Returns the model ID.
|
58 |
+
"""
|
59 |
+
import getpass
|
60 |
+
for attempts in range(max_attempts):
|
61 |
+
LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}')
|
62 |
+
input_key = getpass.getpass(f'Enter API key from {API_KEY_URL} ')
|
63 |
+
self.api_key = input_key.split('_')[0] # remove model id if present
|
64 |
+
if self.authenticate():
|
65 |
+
return True
|
66 |
+
raise ConnectionError(emojis(f'{PREFIX}Failed to authenticate ❌'))
|
67 |
+
|
68 |
+
def authenticate(self) -> bool:
|
69 |
+
"""
|
70 |
+
Attempt to authenticate with the server using either id_token or API key.
|
71 |
+
|
72 |
+
Returns:
|
73 |
+
bool: True if authentication is successful, False otherwise.
|
74 |
+
"""
|
75 |
+
try:
|
76 |
+
header = self.get_auth_header()
|
77 |
+
if header:
|
78 |
+
r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header)
|
79 |
+
if not r.json().get('success', False):
|
80 |
+
raise ConnectionError('Unable to authenticate.')
|
81 |
+
return True
|
82 |
+
raise ConnectionError('User has not authenticated locally.')
|
83 |
+
except ConnectionError:
|
84 |
+
self.id_token = self.api_key = False # reset invalid
|
85 |
+
LOGGER.warning(f'{PREFIX}Invalid API key ⚠️')
|
86 |
+
return False
|
87 |
+
|
88 |
+
def auth_with_cookies(self) -> bool:
|
89 |
+
"""
|
90 |
+
Attempt to fetch authentication via cookies and set id_token.
|
91 |
+
User must be logged in to HUB and running in a supported browser.
|
92 |
+
|
93 |
+
Returns:
|
94 |
+
bool: True if authentication is successful, False otherwise.
|
95 |
+
"""
|
96 |
+
if not is_colab():
|
97 |
+
return False # Currently only works with Colab
|
98 |
+
try:
|
99 |
+
authn = request_with_credentials(f'{HUB_API_ROOT}/v1/auth/auto')
|
100 |
+
if authn.get('success', False):
|
101 |
+
self.id_token = authn.get('data', {}).get('idToken', None)
|
102 |
+
self.authenticate()
|
103 |
+
return True
|
104 |
+
raise ConnectionError('Unable to fetch browser authentication details.')
|
105 |
+
except ConnectionError:
|
106 |
+
self.id_token = False # reset invalid
|
107 |
+
return False
|
108 |
+
|
109 |
+
def get_auth_header(self):
|
110 |
+
"""
|
111 |
+
Get the authentication header for making API requests.
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
(dict): The authentication header if id_token or API key is set, None otherwise.
|
115 |
+
"""
|
116 |
+
if self.id_token:
|
117 |
+
return {'authorization': f'Bearer {self.id_token}'}
|
118 |
+
elif self.api_key:
|
119 |
+
return {'x-api-key': self.api_key}
|
120 |
+
else:
|
121 |
+
return None
|
122 |
+
|
123 |
+
def get_state(self) -> bool:
|
124 |
+
"""
|
125 |
+
Get the authentication state.
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
bool: True if either id_token or API key is set, False otherwise.
|
129 |
+
"""
|
130 |
+
return self.id_token or self.api_key
|
131 |
+
|
132 |
+
def set_api_key(self, key: str):
|
133 |
+
"""
|
134 |
+
Set the API key for authentication.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
key (str): The API key string.
|
138 |
+
"""
|
139 |
+
self.api_key = key
|
ultralytics/hub/session.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
import signal
|
3 |
+
import sys
|
4 |
+
from pathlib import Path
|
5 |
+
from time import sleep
|
6 |
+
|
7 |
+
import requests
|
8 |
+
|
9 |
+
from ultralytics.hub.utils import HUB_API_ROOT, PREFIX, smart_request
|
10 |
+
from ultralytics.yolo.utils import LOGGER, __version__, checks, emojis, is_colab, threaded
|
11 |
+
from ultralytics.yolo.utils.errors import HUBModelError
|
12 |
+
|
13 |
+
AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
|
14 |
+
|
15 |
+
|
16 |
+
class HUBTrainingSession:
|
17 |
+
"""
|
18 |
+
HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
url (str): Model identifier used to initialize the HUB training session.
|
22 |
+
|
23 |
+
Attributes:
|
24 |
+
agent_id (str): Identifier for the instance communicating with the server.
|
25 |
+
model_id (str): Identifier for the YOLOv5 model being trained.
|
26 |
+
model_url (str): URL for the model in Ultralytics HUB.
|
27 |
+
api_url (str): API URL for the model in Ultralytics HUB.
|
28 |
+
auth_header (Dict): Authentication header for the Ultralytics HUB API requests.
|
29 |
+
rate_limits (Dict): Rate limits for different API calls (in seconds).
|
30 |
+
timers (Dict): Timers for rate limiting.
|
31 |
+
metrics_queue (Dict): Queue for the model's metrics.
|
32 |
+
model (Dict): Model data fetched from Ultralytics HUB.
|
33 |
+
alive (bool): Indicates if the heartbeat loop is active.
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self, url):
|
37 |
+
"""
|
38 |
+
Initialize the HUBTrainingSession with the provided model identifier.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
url (str): Model identifier used to initialize the HUB training session.
|
42 |
+
It can be a URL string or a model key with specific format.
|
43 |
+
|
44 |
+
Raises:
|
45 |
+
ValueError: If the provided model identifier is invalid.
|
46 |
+
ConnectionError: If connecting with global API key is not supported.
|
47 |
+
"""
|
48 |
+
|
49 |
+
from ultralytics.hub.auth import Auth
|
50 |
+
|
51 |
+
# Parse input
|
52 |
+
if url.startswith('https://hub.ultralytics.com/models/'):
|
53 |
+
url = url.split('https://hub.ultralytics.com/models/')[-1]
|
54 |
+
if [len(x) for x in url.split('_')] == [42, 20]:
|
55 |
+
key, model_id = url.split('_')
|
56 |
+
elif len(url) == 20:
|
57 |
+
key, model_id = '', url
|
58 |
+
else:
|
59 |
+
raise HUBModelError(f"model='{url}' not found. Check format is correct, i.e. "
|
60 |
+
f"model='https://hub.ultralytics.com/models/MODEL_ID' and try again.")
|
61 |
+
|
62 |
+
# Authorize
|
63 |
+
auth = Auth(key)
|
64 |
+
self.agent_id = None # identifies which instance is communicating with server
|
65 |
+
self.model_id = model_id
|
66 |
+
self.model_url = f'https://hub.ultralytics.com/models/{model_id}'
|
67 |
+
self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
|
68 |
+
self.auth_header = auth.get_auth_header()
|
69 |
+
self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds)
|
70 |
+
self.timers = {} # rate limit timers (seconds)
|
71 |
+
self.metrics_queue = {} # metrics queue
|
72 |
+
self.model = self._get_model()
|
73 |
+
self.alive = True
|
74 |
+
self._start_heartbeat() # start heartbeats
|
75 |
+
self._register_signal_handlers()
|
76 |
+
LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
|
77 |
+
|
78 |
+
def _register_signal_handlers(self):
|
79 |
+
"""Register signal handlers for SIGTERM and SIGINT signals to gracefully handle termination."""
|
80 |
+
signal.signal(signal.SIGTERM, self._handle_signal)
|
81 |
+
signal.signal(signal.SIGINT, self._handle_signal)
|
82 |
+
|
83 |
+
def _handle_signal(self, signum, frame):
|
84 |
+
"""
|
85 |
+
Handle kill signals and prevent heartbeats from being sent on Colab after termination.
|
86 |
+
This method does not use frame, it is included as it is passed by signal.
|
87 |
+
"""
|
88 |
+
if self.alive is True:
|
89 |
+
LOGGER.info(f'{PREFIX}Kill signal received! ❌')
|
90 |
+
self._stop_heartbeat()
|
91 |
+
sys.exit(signum)
|
92 |
+
|
93 |
+
def _stop_heartbeat(self):
|
94 |
+
"""Terminate the heartbeat loop."""
|
95 |
+
self.alive = False
|
96 |
+
|
97 |
+
def upload_metrics(self):
|
98 |
+
"""Upload model metrics to Ultralytics HUB."""
|
99 |
+
payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'}
|
100 |
+
smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2)
|
101 |
+
|
102 |
+
def _get_model(self):
|
103 |
+
"""Fetch and return model data from Ultralytics HUB."""
|
104 |
+
api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
|
105 |
+
|
106 |
+
try:
|
107 |
+
response = smart_request('get', api_url, headers=self.auth_header, thread=False, code=0)
|
108 |
+
data = response.json().get('data', None)
|
109 |
+
|
110 |
+
if data.get('status', None) == 'trained':
|
111 |
+
raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))
|
112 |
+
|
113 |
+
if not data.get('data', None):
|
114 |
+
raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix
|
115 |
+
self.model_id = data['id']
|
116 |
+
|
117 |
+
if data['status'] == 'new': # new model to start training
|
118 |
+
self.train_args = {
|
119 |
+
# TODO: deprecate 'batch_size' key for 'batch' in 3Q23
|
120 |
+
'batch': data['batch' if ('batch' in data) else 'batch_size'],
|
121 |
+
'epochs': data['epochs'],
|
122 |
+
'imgsz': data['imgsz'],
|
123 |
+
'patience': data['patience'],
|
124 |
+
'device': data['device'],
|
125 |
+
'cache': data['cache'],
|
126 |
+
'data': data['data']}
|
127 |
+
self.model_file = data.get('cfg') or data.get('weights') # cfg for pretrained=False
|
128 |
+
self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
|
129 |
+
elif data['status'] == 'training': # existing model to resume training
|
130 |
+
self.train_args = {'data': data['data'], 'resume': True}
|
131 |
+
self.model_file = data['resume']
|
132 |
+
|
133 |
+
return data
|
134 |
+
except requests.exceptions.ConnectionError as e:
|
135 |
+
raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e
|
136 |
+
except Exception:
|
137 |
+
raise
|
138 |
+
|
139 |
+
def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
|
140 |
+
"""
|
141 |
+
Upload a model checkpoint to Ultralytics HUB.
|
142 |
+
|
143 |
+
Args:
|
144 |
+
epoch (int): The current training epoch.
|
145 |
+
weights (str): Path to the model weights file.
|
146 |
+
is_best (bool): Indicates if the current model is the best one so far.
|
147 |
+
map (float): Mean average precision of the model.
|
148 |
+
final (bool): Indicates if the model is the final model after training.
|
149 |
+
"""
|
150 |
+
if Path(weights).is_file():
|
151 |
+
with open(weights, 'rb') as f:
|
152 |
+
file = f.read()
|
153 |
+
else:
|
154 |
+
LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.')
|
155 |
+
file = None
|
156 |
+
url = f'{self.api_url}/upload'
|
157 |
+
# url = 'http://httpbin.org/post' # for debug
|
158 |
+
data = {'epoch': epoch}
|
159 |
+
if final:
|
160 |
+
data.update({'type': 'final', 'map': map})
|
161 |
+
smart_request('post',
|
162 |
+
url,
|
163 |
+
data=data,
|
164 |
+
files={'best.pt': file},
|
165 |
+
headers=self.auth_header,
|
166 |
+
retry=10,
|
167 |
+
timeout=3600,
|
168 |
+
thread=False,
|
169 |
+
progress=True,
|
170 |
+
code=4)
|
171 |
+
else:
|
172 |
+
data.update({'type': 'epoch', 'isBest': bool(is_best)})
|
173 |
+
smart_request('post', url, data=data, files={'last.pt': file}, headers=self.auth_header, code=3)
|
174 |
+
|
175 |
+
@threaded
|
176 |
+
def _start_heartbeat(self):
|
177 |
+
"""Begin a threaded heartbeat loop to report the agent's status to Ultralytics HUB."""
|
178 |
+
while self.alive:
|
179 |
+
r = smart_request('post',
|
180 |
+
f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
|
181 |
+
json={
|
182 |
+
'agent': AGENT_NAME,
|
183 |
+
'agentId': self.agent_id},
|
184 |
+
headers=self.auth_header,
|
185 |
+
retry=0,
|
186 |
+
code=5,
|
187 |
+
thread=False) # already in a thread
|
188 |
+
self.agent_id = r.json().get('data', {}).get('agentId', None)
|
189 |
+
sleep(self.rate_limits['heartbeat'])
|
ultralytics/hub/utils.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
import os
|
4 |
+
import platform
|
5 |
+
import random
|
6 |
+
import sys
|
7 |
+
import threading
|
8 |
+
import time
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
import requests
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
from ultralytics.yolo.utils import (ENVIRONMENT, LOGGER, ONLINE, RANK, SETTINGS, TESTS_RUNNING, TQDM_BAR_FORMAT,
|
15 |
+
TryExcept, __version__, colorstr, get_git_origin_url, is_colab, is_git_dir,
|
16 |
+
is_pip_package)
|
17 |
+
|
18 |
+
PREFIX = colorstr('Ultralytics HUB: ')
|
19 |
+
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
|
20 |
+
HUB_API_ROOT = os.environ.get('ULTRALYTICS_HUB_API', 'https://api.ultralytics.com')
|
21 |
+
|
22 |
+
|
23 |
+
def request_with_credentials(url: str) -> any:
|
24 |
+
"""
|
25 |
+
Make an AJAX request with cookies attached in a Google Colab environment.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
url (str): The URL to make the request to.
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
(any): The response data from the AJAX request.
|
32 |
+
|
33 |
+
Raises:
|
34 |
+
OSError: If the function is not run in a Google Colab environment.
|
35 |
+
"""
|
36 |
+
if not is_colab():
|
37 |
+
raise OSError('request_with_credentials() must run in a Colab environment')
|
38 |
+
from google.colab import output # noqa
|
39 |
+
from IPython import display # noqa
|
40 |
+
display.display(
|
41 |
+
display.Javascript("""
|
42 |
+
window._hub_tmp = new Promise((resolve, reject) => {
|
43 |
+
const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000)
|
44 |
+
fetch("%s", {
|
45 |
+
method: 'POST',
|
46 |
+
credentials: 'include'
|
47 |
+
})
|
48 |
+
.then((response) => resolve(response.json()))
|
49 |
+
.then((json) => {
|
50 |
+
clearTimeout(timeout);
|
51 |
+
}).catch((err) => {
|
52 |
+
clearTimeout(timeout);
|
53 |
+
reject(err);
|
54 |
+
});
|
55 |
+
});
|
56 |
+
""" % url))
|
57 |
+
return output.eval_js('_hub_tmp')
|
58 |
+
|
59 |
+
|
60 |
+
def requests_with_progress(method, url, **kwargs):
|
61 |
+
"""
|
62 |
+
Make an HTTP request using the specified method and URL, with an optional progress bar.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
method (str): The HTTP method to use (e.g. 'GET', 'POST').
|
66 |
+
url (str): The URL to send the request to.
|
67 |
+
**kwargs (dict): Additional keyword arguments to pass to the underlying `requests.request` function.
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
(requests.Response): The response object from the HTTP request.
|
71 |
+
|
72 |
+
Note:
|
73 |
+
If 'progress' is set to True, the progress bar will display the download progress
|
74 |
+
for responses with a known content length.
|
75 |
+
"""
|
76 |
+
progress = kwargs.pop('progress', False)
|
77 |
+
if not progress:
|
78 |
+
return requests.request(method, url, **kwargs)
|
79 |
+
response = requests.request(method, url, stream=True, **kwargs)
|
80 |
+
total = int(response.headers.get('content-length', 0)) # total size
|
81 |
+
pbar = tqdm(total=total, unit='B', unit_scale=True, unit_divisor=1024, bar_format=TQDM_BAR_FORMAT)
|
82 |
+
for data in response.iter_content(chunk_size=1024):
|
83 |
+
pbar.update(len(data))
|
84 |
+
pbar.close()
|
85 |
+
return response
|
86 |
+
|
87 |
+
|
88 |
+
def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbose=True, progress=False, **kwargs):
|
89 |
+
"""
|
90 |
+
Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
method (str): The HTTP method to use for the request. Choices are 'post' and 'get'.
|
94 |
+
url (str): The URL to make the request to.
|
95 |
+
retry (int, optional): Number of retries to attempt before giving up. Default is 3.
|
96 |
+
timeout (int, optional): Timeout in seconds after which the function will give up retrying. Default is 30.
|
97 |
+
thread (bool, optional): Whether to execute the request in a separate daemon thread. Default is True.
|
98 |
+
code (int, optional): An identifier for the request, used for logging purposes. Default is -1.
|
99 |
+
verbose (bool, optional): A flag to determine whether to print out to console or not. Default is True.
|
100 |
+
progress (bool, optional): Whether to show a progress bar during the request. Default is False.
|
101 |
+
**kwargs (dict): Keyword arguments to be passed to the requests function specified in method.
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
(requests.Response): The HTTP response object. If the request is executed in a separate thread, returns None.
|
105 |
+
"""
|
106 |
+
retry_codes = (408, 500) # retry only these codes
|
107 |
+
|
108 |
+
@TryExcept(verbose=verbose)
|
109 |
+
def func(func_method, func_url, **func_kwargs):
|
110 |
+
"""Make HTTP requests with retries and timeouts, with optional progress tracking."""
|
111 |
+
r = None # response
|
112 |
+
t0 = time.time() # initial time for timer
|
113 |
+
for i in range(retry + 1):
|
114 |
+
if (time.time() - t0) > timeout:
|
115 |
+
break
|
116 |
+
r = requests_with_progress(func_method, func_url, **func_kwargs) # i.e. get(url, data, json, files)
|
117 |
+
if r.status_code < 300: # return codes in the 2xx range are generally considered "good" or "successful"
|
118 |
+
break
|
119 |
+
try:
|
120 |
+
m = r.json().get('message', 'No JSON message.')
|
121 |
+
except AttributeError:
|
122 |
+
m = 'Unable to read JSON.'
|
123 |
+
if i == 0:
|
124 |
+
if r.status_code in retry_codes:
|
125 |
+
m += f' Retrying {retry}x for {timeout}s.' if retry else ''
|
126 |
+
elif r.status_code == 429: # rate limit
|
127 |
+
h = r.headers # response headers
|
128 |
+
m = f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " \
|
129 |
+
f"Please retry after {h['Retry-After']}s."
|
130 |
+
if verbose:
|
131 |
+
LOGGER.warning(f'{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})')
|
132 |
+
if r.status_code not in retry_codes:
|
133 |
+
return r
|
134 |
+
time.sleep(2 ** i) # exponential standoff
|
135 |
+
return r
|
136 |
+
|
137 |
+
args = method, url
|
138 |
+
kwargs['progress'] = progress
|
139 |
+
if thread:
|
140 |
+
threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
|
141 |
+
else:
|
142 |
+
return func(*args, **kwargs)
|
143 |
+
|
144 |
+
|
145 |
+
class Events:
|
146 |
+
"""
|
147 |
+
A class for collecting anonymous event analytics. Event analytics are enabled when sync=True in settings and
|
148 |
+
disabled when sync=False. Run 'yolo settings' to see and update settings YAML file.
|
149 |
+
|
150 |
+
Attributes:
|
151 |
+
url (str): The URL to send anonymous events.
|
152 |
+
rate_limit (float): The rate limit in seconds for sending events.
|
153 |
+
metadata (dict): A dictionary containing metadata about the environment.
|
154 |
+
enabled (bool): A flag to enable or disable Events based on certain conditions.
|
155 |
+
"""
|
156 |
+
|
157 |
+
url = 'https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw'
|
158 |
+
|
159 |
+
def __init__(self):
|
160 |
+
"""
|
161 |
+
Initializes the Events object with default values for events, rate_limit, and metadata.
|
162 |
+
"""
|
163 |
+
self.events = [] # events list
|
164 |
+
self.rate_limit = 60.0 # rate limit (seconds)
|
165 |
+
self.t = 0.0 # rate limit timer (seconds)
|
166 |
+
self.metadata = {
|
167 |
+
'cli': Path(sys.argv[0]).name == 'yolo',
|
168 |
+
'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
|
169 |
+
'python': '.'.join(platform.python_version_tuple()[:2]), # i.e. 3.10
|
170 |
+
'version': __version__,
|
171 |
+
'env': ENVIRONMENT,
|
172 |
+
'session_id': round(random.random() * 1E15),
|
173 |
+
'engagement_time_msec': 1000}
|
174 |
+
self.enabled = \
|
175 |
+
SETTINGS['sync'] and \
|
176 |
+
RANK in (-1, 0) and \
|
177 |
+
not TESTS_RUNNING and \
|
178 |
+
ONLINE and \
|
179 |
+
(is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git')
|
180 |
+
|
181 |
+
def __call__(self, cfg):
|
182 |
+
"""
|
183 |
+
Attempts to add a new event to the events list and send events if the rate limit is reached.
|
184 |
+
|
185 |
+
Args:
|
186 |
+
cfg (IterableSimpleNamespace): The configuration object containing mode and task information.
|
187 |
+
"""
|
188 |
+
if not self.enabled:
|
189 |
+
# Events disabled, do nothing
|
190 |
+
return
|
191 |
+
|
192 |
+
# Attempt to add to events
|
193 |
+
if len(self.events) < 25: # Events list limited to 25 events (drop any events past this)
|
194 |
+
params = {**self.metadata, **{'task': cfg.task}}
|
195 |
+
if cfg.mode == 'export':
|
196 |
+
params['format'] = cfg.format
|
197 |
+
self.events.append({'name': cfg.mode, 'params': params})
|
198 |
+
|
199 |
+
# Check rate limit
|
200 |
+
t = time.time()
|
201 |
+
if (t - self.t) < self.rate_limit:
|
202 |
+
# Time is under rate limiter, wait to send
|
203 |
+
return
|
204 |
+
|
205 |
+
# Time is over rate limiter, send now
|
206 |
+
data = {'client_id': SETTINGS['uuid'], 'events': self.events} # SHA-256 anonymized UUID hash and events list
|
207 |
+
|
208 |
+
# POST equivalent to requests.post(self.url, json=data)
|
209 |
+
smart_request('post', self.url, json=data, retry=0, verbose=False)
|
210 |
+
|
211 |
+
# Reset events and rate limit timer
|
212 |
+
self.events = []
|
213 |
+
self.t = t
|
214 |
+
|
215 |
+
|
216 |
+
# Run below code on hub/utils init -------------------------------------------------------------------------------------
|
217 |
+
events = Events()
|
ultralytics/models/README.md
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Models
|
2 |
+
|
3 |
+
Welcome to the Ultralytics Models directory! Here you will find a wide variety of pre-configured model configuration
|
4 |
+
files (`*.yaml`s) that can be used to create custom YOLO models. The models in this directory have been expertly crafted
|
5 |
+
and fine-tuned by the Ultralytics team to provide the best performance for a wide range of object detection and image
|
6 |
+
segmentation tasks.
|
7 |
+
|
8 |
+
These model configurations cover a wide range of scenarios, from simple object detection to more complex tasks like
|
9 |
+
instance segmentation and object tracking. They are also designed to run efficiently on a variety of hardware platforms,
|
10 |
+
from CPUs to GPUs. Whether you are a seasoned machine learning practitioner or just getting started with YOLO, this
|
11 |
+
directory provides a great starting point for your custom model development needs.
|
12 |
+
|
13 |
+
To get started, simply browse through the models in this directory and find one that best suits your needs. Once you've
|
14 |
+
selected a model, you can use the provided `*.yaml` file to train and deploy your custom YOLO model with ease. See full
|
15 |
+
details at the Ultralytics [Docs](https://docs.ultralytics.com/models), and if you need help or have any questions, feel free
|
16 |
+
to reach out to the Ultralytics team for support. So, don't wait, start creating your custom YOLO model now!
|
17 |
+
|
18 |
+
### Usage
|
19 |
+
|
20 |
+
Model `*.yaml` files may be used directly in the Command Line Interface (CLI) with a `yolo` command:
|
21 |
+
|
22 |
+
```bash
|
23 |
+
yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=100
|
24 |
+
```
|
25 |
+
|
26 |
+
They may also be used directly in a Python environment, and accepts the same
|
27 |
+
[arguments](https://docs.ultralytics.com/usage/cfg/) as in the CLI example above:
|
28 |
+
|
29 |
+
```python
|
30 |
+
from ultralytics import YOLO
|
31 |
+
|
32 |
+
model = YOLO("model.yaml") # build a YOLOv8n model from scratch
|
33 |
+
# YOLO("model.pt") use pre-trained model if available
|
34 |
+
model.info() # display model information
|
35 |
+
model.train(data="coco128.yaml", epochs=100) # train the model
|
36 |
+
```
|
37 |
+
|
38 |
+
## Pre-trained Model Architectures
|
39 |
+
|
40 |
+
Ultralytics supports many model architectures. Visit https://docs.ultralytics.com/models to view detailed information
|
41 |
+
and usage. Any of these models can be used by loading their configs or pretrained checkpoints if available.
|
42 |
+
|
43 |
+
## Contributing New Models
|
44 |
+
|
45 |
+
If you've developed a new model architecture or have improvements for existing models that you'd like to contribute to the Ultralytics community, please submit your contribution in a new Pull Request. For more details, visit our [Contributing Guide](https://docs.ultralytics.com/help/contributing).
|
ultralytics/models/rt-detr/rt-detr-l.yaml
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
l: [1.00, 1.00, 1024]
|
9 |
+
|
10 |
+
backbone:
|
11 |
+
# [from, repeats, module, args]
|
12 |
+
- [-1, 1, HGStem, [32, 48]] # 0-P2/4
|
13 |
+
- [-1, 6, HGBlock, [48, 128, 3]] # stage 1
|
14 |
+
|
15 |
+
- [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
|
16 |
+
- [-1, 6, HGBlock, [96, 512, 3]] # stage 2
|
17 |
+
|
18 |
+
- [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P3/16
|
19 |
+
- [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut
|
20 |
+
- [-1, 6, HGBlock, [192, 1024, 5, True, True]]
|
21 |
+
- [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3
|
22 |
+
|
23 |
+
- [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P4/32
|
24 |
+
- [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4
|
25 |
+
|
26 |
+
head:
|
27 |
+
- [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2
|
28 |
+
- [-1, 1, AIFI, [1024, 8]]
|
29 |
+
- [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0
|
30 |
+
|
31 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
32 |
+
- [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1
|
33 |
+
- [[-2, -1], 1, Concat, [1]]
|
34 |
+
- [-1, 3, RepC3, [256]] # 16, fpn_blocks.0
|
35 |
+
- [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1
|
36 |
+
|
37 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
38 |
+
- [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 input_proj.0
|
39 |
+
- [[-2, -1], 1, Concat, [1]] # cat backbone P4
|
40 |
+
- [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1
|
41 |
+
|
42 |
+
- [-1, 1, Conv, [256, 3, 2]] # 22, downsample_convs.0
|
43 |
+
- [[-1, 17], 1, Concat, [1]] # cat Y4
|
44 |
+
- [-1, 3, RepC3, [256]] # F4 (24), pan_blocks.0
|
45 |
+
|
46 |
+
- [-1, 1, Conv, [256, 3, 2]] # 25, downsample_convs.1
|
47 |
+
- [[-1, 12], 1, Concat, [1]] # cat Y5
|
48 |
+
- [-1, 3, RepC3, [256]] # F5 (27), pan_blocks.1
|
49 |
+
|
50 |
+
- [[21, 24, 27], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)
|
ultralytics/models/rt-detr/rt-detr-x.yaml
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# RT-DETR-x object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
x: [1.00, 1.00, 2048]
|
9 |
+
|
10 |
+
backbone:
|
11 |
+
# [from, repeats, module, args]
|
12 |
+
- [-1, 1, HGStem, [32, 64]] # 0-P2/4
|
13 |
+
- [-1, 6, HGBlock, [64, 128, 3]] # stage 1
|
14 |
+
|
15 |
+
- [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
|
16 |
+
- [-1, 6, HGBlock, [128, 512, 3]]
|
17 |
+
- [-1, 6, HGBlock, [128, 512, 3, False, True]] # 4-stage 2
|
18 |
+
|
19 |
+
- [-1, 1, DWConv, [512, 3, 2, 1, False]] # 5-P3/16
|
20 |
+
- [-1, 6, HGBlock, [256, 1024, 5, True, False]] # cm, c2, k, light, shortcut
|
21 |
+
- [-1, 6, HGBlock, [256, 1024, 5, True, True]]
|
22 |
+
- [-1, 6, HGBlock, [256, 1024, 5, True, True]]
|
23 |
+
- [-1, 6, HGBlock, [256, 1024, 5, True, True]]
|
24 |
+
- [-1, 6, HGBlock, [256, 1024, 5, True, True]] # 10-stage 3
|
25 |
+
|
26 |
+
- [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 11-P4/32
|
27 |
+
- [-1, 6, HGBlock, [512, 2048, 5, True, False]]
|
28 |
+
- [-1, 6, HGBlock, [512, 2048, 5, True, True]] # 13-stage 4
|
29 |
+
|
30 |
+
head:
|
31 |
+
- [-1, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 14 input_proj.2
|
32 |
+
- [-1, 1, AIFI, [2048, 8]]
|
33 |
+
- [-1, 1, Conv, [384, 1, 1]] # 16, Y5, lateral_convs.0
|
34 |
+
|
35 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
36 |
+
- [10, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 18 input_proj.1
|
37 |
+
- [[-2, -1], 1, Concat, [1]]
|
38 |
+
- [-1, 3, RepC3, [384]] # 20, fpn_blocks.0
|
39 |
+
- [-1, 1, Conv, [384, 1, 1]] # 21, Y4, lateral_convs.1
|
40 |
+
|
41 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
42 |
+
- [4, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 23 input_proj.0
|
43 |
+
- [[-2, -1], 1, Concat, [1]] # cat backbone P4
|
44 |
+
- [-1, 3, RepC3, [384]] # X3 (25), fpn_blocks.1
|
45 |
+
|
46 |
+
- [-1, 1, Conv, [384, 3, 2]] # 26, downsample_convs.0
|
47 |
+
- [[-1, 21], 1, Concat, [1]] # cat Y4
|
48 |
+
- [-1, 3, RepC3, [384]] # F4 (28), pan_blocks.0
|
49 |
+
|
50 |
+
- [-1, 1, Conv, [384, 3, 2]] # 29, downsample_convs.1
|
51 |
+
- [[-1, 16], 1, Concat, [1]] # cat Y5
|
52 |
+
- [-1, 3, RepC3, [384]] # F5 (31), pan_blocks.1
|
53 |
+
|
54 |
+
- [[25, 28, 31], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)
|
ultralytics/models/v3/yolov3-spp.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv3-SPP object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
depth_multiple: 1.0 # model depth multiple
|
7 |
+
width_multiple: 1.0 # layer channel multiple
|
8 |
+
|
9 |
+
# darknet53 backbone
|
10 |
+
backbone:
|
11 |
+
# [from, number, module, args]
|
12 |
+
[[-1, 1, Conv, [32, 3, 1]], # 0
|
13 |
+
[-1, 1, Conv, [64, 3, 2]], # 1-P1/2
|
14 |
+
[-1, 1, Bottleneck, [64]],
|
15 |
+
[-1, 1, Conv, [128, 3, 2]], # 3-P2/4
|
16 |
+
[-1, 2, Bottleneck, [128]],
|
17 |
+
[-1, 1, Conv, [256, 3, 2]], # 5-P3/8
|
18 |
+
[-1, 8, Bottleneck, [256]],
|
19 |
+
[-1, 1, Conv, [512, 3, 2]], # 7-P4/16
|
20 |
+
[-1, 8, Bottleneck, [512]],
|
21 |
+
[-1, 1, Conv, [1024, 3, 2]], # 9-P5/32
|
22 |
+
[-1, 4, Bottleneck, [1024]], # 10
|
23 |
+
]
|
24 |
+
|
25 |
+
# YOLOv3-SPP head
|
26 |
+
head:
|
27 |
+
[[-1, 1, Bottleneck, [1024, False]],
|
28 |
+
[-1, 1, SPP, [512, [5, 9, 13]]],
|
29 |
+
[-1, 1, Conv, [1024, 3, 1]],
|
30 |
+
[-1, 1, Conv, [512, 1, 1]],
|
31 |
+
[-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large)
|
32 |
+
|
33 |
+
[-2, 1, Conv, [256, 1, 1]],
|
34 |
+
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
35 |
+
[[-1, 8], 1, Concat, [1]], # cat backbone P4
|
36 |
+
[-1, 1, Bottleneck, [512, False]],
|
37 |
+
[-1, 1, Bottleneck, [512, False]],
|
38 |
+
[-1, 1, Conv, [256, 1, 1]],
|
39 |
+
[-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium)
|
40 |
+
|
41 |
+
[-2, 1, Conv, [128, 1, 1]],
|
42 |
+
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
43 |
+
[[-1, 6], 1, Concat, [1]], # cat backbone P3
|
44 |
+
[-1, 1, Bottleneck, [256, False]],
|
45 |
+
[-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small)
|
46 |
+
|
47 |
+
[[27, 22, 15], 1, Detect, [nc]], # Detect(P3, P4, P5)
|
48 |
+
]
|
ultralytics/models/v3/yolov3-tiny.yaml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv3-tiny object detection model with P4-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
depth_multiple: 1.0 # model depth multiple
|
7 |
+
width_multiple: 1.0 # layer channel multiple
|
8 |
+
|
9 |
+
# YOLOv3-tiny backbone
|
10 |
+
backbone:
|
11 |
+
# [from, number, module, args]
|
12 |
+
[[-1, 1, Conv, [16, 3, 1]], # 0
|
13 |
+
[-1, 1, nn.MaxPool2d, [2, 2, 0]], # 1-P1/2
|
14 |
+
[-1, 1, Conv, [32, 3, 1]],
|
15 |
+
[-1, 1, nn.MaxPool2d, [2, 2, 0]], # 3-P2/4
|
16 |
+
[-1, 1, Conv, [64, 3, 1]],
|
17 |
+
[-1, 1, nn.MaxPool2d, [2, 2, 0]], # 5-P3/8
|
18 |
+
[-1, 1, Conv, [128, 3, 1]],
|
19 |
+
[-1, 1, nn.MaxPool2d, [2, 2, 0]], # 7-P4/16
|
20 |
+
[-1, 1, Conv, [256, 3, 1]],
|
21 |
+
[-1, 1, nn.MaxPool2d, [2, 2, 0]], # 9-P5/32
|
22 |
+
[-1, 1, Conv, [512, 3, 1]],
|
23 |
+
[-1, 1, nn.ZeroPad2d, [[0, 1, 0, 1]]], # 11
|
24 |
+
[-1, 1, nn.MaxPool2d, [2, 1, 0]], # 12
|
25 |
+
]
|
26 |
+
|
27 |
+
# YOLOv3-tiny head
|
28 |
+
head:
|
29 |
+
[[-1, 1, Conv, [1024, 3, 1]],
|
30 |
+
[-1, 1, Conv, [256, 1, 1]],
|
31 |
+
[-1, 1, Conv, [512, 3, 1]], # 15 (P5/32-large)
|
32 |
+
|
33 |
+
[-2, 1, Conv, [128, 1, 1]],
|
34 |
+
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
35 |
+
[[-1, 8], 1, Concat, [1]], # cat backbone P4
|
36 |
+
[-1, 1, Conv, [256, 3, 1]], # 19 (P4/16-medium)
|
37 |
+
|
38 |
+
[[19, 15], 1, Detect, [nc]], # Detect(P4, P5)
|
39 |
+
]
|
ultralytics/models/v3/yolov3.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv3 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
depth_multiple: 1.0 # model depth multiple
|
7 |
+
width_multiple: 1.0 # layer channel multiple
|
8 |
+
|
9 |
+
# darknet53 backbone
|
10 |
+
backbone:
|
11 |
+
# [from, number, module, args]
|
12 |
+
[[-1, 1, Conv, [32, 3, 1]], # 0
|
13 |
+
[-1, 1, Conv, [64, 3, 2]], # 1-P1/2
|
14 |
+
[-1, 1, Bottleneck, [64]],
|
15 |
+
[-1, 1, Conv, [128, 3, 2]], # 3-P2/4
|
16 |
+
[-1, 2, Bottleneck, [128]],
|
17 |
+
[-1, 1, Conv, [256, 3, 2]], # 5-P3/8
|
18 |
+
[-1, 8, Bottleneck, [256]],
|
19 |
+
[-1, 1, Conv, [512, 3, 2]], # 7-P4/16
|
20 |
+
[-1, 8, Bottleneck, [512]],
|
21 |
+
[-1, 1, Conv, [1024, 3, 2]], # 9-P5/32
|
22 |
+
[-1, 4, Bottleneck, [1024]], # 10
|
23 |
+
]
|
24 |
+
|
25 |
+
# YOLOv3 head
|
26 |
+
head:
|
27 |
+
[[-1, 1, Bottleneck, [1024, False]],
|
28 |
+
[-1, 1, Conv, [512, 1, 1]],
|
29 |
+
[-1, 1, Conv, [1024, 3, 1]],
|
30 |
+
[-1, 1, Conv, [512, 1, 1]],
|
31 |
+
[-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large)
|
32 |
+
|
33 |
+
[-2, 1, Conv, [256, 1, 1]],
|
34 |
+
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
35 |
+
[[-1, 8], 1, Concat, [1]], # cat backbone P4
|
36 |
+
[-1, 1, Bottleneck, [512, False]],
|
37 |
+
[-1, 1, Bottleneck, [512, False]],
|
38 |
+
[-1, 1, Conv, [256, 1, 1]],
|
39 |
+
[-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium)
|
40 |
+
|
41 |
+
[-2, 1, Conv, [128, 1, 1]],
|
42 |
+
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
43 |
+
[[-1, 6], 1, Concat, [1]], # cat backbone P3
|
44 |
+
[-1, 1, Bottleneck, [256, False]],
|
45 |
+
[-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small)
|
46 |
+
|
47 |
+
[[27, 22, 15], 1, Detect, [nc]], # Detect(P3, P4, P5)
|
48 |
+
]
|
ultralytics/models/v5/yolov5-p6.yaml
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv5 object detection model with P3-P6 outputs. For details see https://docs.ultralytics.com/models/yolov5
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov5n-p6.yaml' will call yolov5-p6.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
n: [0.33, 0.25, 1024]
|
9 |
+
s: [0.33, 0.50, 1024]
|
10 |
+
m: [0.67, 0.75, 1024]
|
11 |
+
l: [1.00, 1.00, 1024]
|
12 |
+
x: [1.33, 1.25, 1024]
|
13 |
+
|
14 |
+
# YOLOv5 v6.0 backbone
|
15 |
+
backbone:
|
16 |
+
# [from, number, module, args]
|
17 |
+
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
|
18 |
+
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
|
19 |
+
[-1, 3, C3, [128]],
|
20 |
+
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
|
21 |
+
[-1, 6, C3, [256]],
|
22 |
+
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
|
23 |
+
[-1, 9, C3, [512]],
|
24 |
+
[-1, 1, Conv, [768, 3, 2]], # 7-P5/32
|
25 |
+
[-1, 3, C3, [768]],
|
26 |
+
[-1, 1, Conv, [1024, 3, 2]], # 9-P6/64
|
27 |
+
[-1, 3, C3, [1024]],
|
28 |
+
[-1, 1, SPPF, [1024, 5]], # 11
|
29 |
+
]
|
30 |
+
|
31 |
+
# YOLOv5 v6.0 head
|
32 |
+
head:
|
33 |
+
[[-1, 1, Conv, [768, 1, 1]],
|
34 |
+
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
35 |
+
[[-1, 8], 1, Concat, [1]], # cat backbone P5
|
36 |
+
[-1, 3, C3, [768, False]], # 15
|
37 |
+
|
38 |
+
[-1, 1, Conv, [512, 1, 1]],
|
39 |
+
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
40 |
+
[[-1, 6], 1, Concat, [1]], # cat backbone P4
|
41 |
+
[-1, 3, C3, [512, False]], # 19
|
42 |
+
|
43 |
+
[-1, 1, Conv, [256, 1, 1]],
|
44 |
+
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
45 |
+
[[-1, 4], 1, Concat, [1]], # cat backbone P3
|
46 |
+
[-1, 3, C3, [256, False]], # 23 (P3/8-small)
|
47 |
+
|
48 |
+
[-1, 1, Conv, [256, 3, 2]],
|
49 |
+
[[-1, 20], 1, Concat, [1]], # cat head P4
|
50 |
+
[-1, 3, C3, [512, False]], # 26 (P4/16-medium)
|
51 |
+
|
52 |
+
[-1, 1, Conv, [512, 3, 2]],
|
53 |
+
[[-1, 16], 1, Concat, [1]], # cat head P5
|
54 |
+
[-1, 3, C3, [768, False]], # 29 (P5/32-large)
|
55 |
+
|
56 |
+
[-1, 1, Conv, [768, 3, 2]],
|
57 |
+
[[-1, 12], 1, Concat, [1]], # cat head P6
|
58 |
+
[-1, 3, C3, [1024, False]], # 32 (P6/64-xlarge)
|
59 |
+
|
60 |
+
[[23, 26, 29, 32], 1, Detect, [nc]], # Detect(P3, P4, P5, P6)
|
61 |
+
]
|
ultralytics/models/v5/yolov5.yaml
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv5 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov5
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov5n.yaml' will call yolov5.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
n: [0.33, 0.25, 1024]
|
9 |
+
s: [0.33, 0.50, 1024]
|
10 |
+
m: [0.67, 0.75, 1024]
|
11 |
+
l: [1.00, 1.00, 1024]
|
12 |
+
x: [1.33, 1.25, 1024]
|
13 |
+
|
14 |
+
# YOLOv5 v6.0 backbone
|
15 |
+
backbone:
|
16 |
+
# [from, number, module, args]
|
17 |
+
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
|
18 |
+
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
|
19 |
+
[-1, 3, C3, [128]],
|
20 |
+
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
|
21 |
+
[-1, 6, C3, [256]],
|
22 |
+
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
|
23 |
+
[-1, 9, C3, [512]],
|
24 |
+
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
|
25 |
+
[-1, 3, C3, [1024]],
|
26 |
+
[-1, 1, SPPF, [1024, 5]], # 9
|
27 |
+
]
|
28 |
+
|
29 |
+
# YOLOv5 v6.0 head
|
30 |
+
head:
|
31 |
+
[[-1, 1, Conv, [512, 1, 1]],
|
32 |
+
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
33 |
+
[[-1, 6], 1, Concat, [1]], # cat backbone P4
|
34 |
+
[-1, 3, C3, [512, False]], # 13
|
35 |
+
|
36 |
+
[-1, 1, Conv, [256, 1, 1]],
|
37 |
+
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
38 |
+
[[-1, 4], 1, Concat, [1]], # cat backbone P3
|
39 |
+
[-1, 3, C3, [256, False]], # 17 (P3/8-small)
|
40 |
+
|
41 |
+
[-1, 1, Conv, [256, 3, 2]],
|
42 |
+
[[-1, 14], 1, Concat, [1]], # cat head P4
|
43 |
+
[-1, 3, C3, [512, False]], # 20 (P4/16-medium)
|
44 |
+
|
45 |
+
[-1, 1, Conv, [512, 3, 2]],
|
46 |
+
[[-1, 10], 1, Concat, [1]], # cat head P5
|
47 |
+
[-1, 3, C3, [1024, False]], # 23 (P5/32-large)
|
48 |
+
|
49 |
+
[[17, 20, 23], 1, Detect, [nc]], # Detect(P3, P4, P5)
|
50 |
+
]
|
ultralytics/models/v6/yolov6.yaml
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv6 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/models/yolov6
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
act: nn.ReLU()
|
6 |
+
nc: 80 # number of classes
|
7 |
+
scales: # model compound scaling constants, i.e. 'model=yolov6n.yaml' will call yolov8.yaml with scale 'n'
|
8 |
+
# [depth, width, max_channels]
|
9 |
+
n: [0.33, 0.25, 1024]
|
10 |
+
s: [0.33, 0.50, 1024]
|
11 |
+
m: [0.67, 0.75, 768]
|
12 |
+
l: [1.00, 1.00, 512]
|
13 |
+
x: [1.00, 1.25, 512]
|
14 |
+
|
15 |
+
# YOLOv6-3.0s backbone
|
16 |
+
backbone:
|
17 |
+
# [from, repeats, module, args]
|
18 |
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
19 |
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
20 |
+
- [-1, 6, Conv, [128, 3, 1]]
|
21 |
+
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
22 |
+
- [-1, 12, Conv, [256, 3, 1]]
|
23 |
+
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
24 |
+
- [-1, 18, Conv, [512, 3, 1]]
|
25 |
+
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
26 |
+
- [-1, 6, Conv, [1024, 3, 1]]
|
27 |
+
- [-1, 1, SPPF, [1024, 5]] # 9
|
28 |
+
|
29 |
+
# YOLOv6-3.0s head
|
30 |
+
head:
|
31 |
+
- [-1, 1, Conv, [256, 1, 1]]
|
32 |
+
- [-1, 1, nn.ConvTranspose2d, [256, 2, 2, 0]]
|
33 |
+
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
34 |
+
- [-1, 1, Conv, [256, 3, 1]]
|
35 |
+
- [-1, 9, Conv, [256, 3, 1]] # 14
|
36 |
+
|
37 |
+
- [-1, 1, Conv, [128, 1, 1]]
|
38 |
+
- [-1, 1, nn.ConvTranspose2d, [128, 2, 2, 0]]
|
39 |
+
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
40 |
+
- [-1, 1, Conv, [128, 3, 1]]
|
41 |
+
- [-1, 9, Conv, [128, 3, 1]] # 19
|
42 |
+
|
43 |
+
- [-1, 1, Conv, [128, 3, 2]]
|
44 |
+
- [[-1, 15], 1, Concat, [1]] # cat head P4
|
45 |
+
- [-1, 1, Conv, [256, 3, 1]]
|
46 |
+
- [-1, 9, Conv, [256, 3, 1]] # 23
|
47 |
+
|
48 |
+
- [-1, 1, Conv, [256, 3, 2]]
|
49 |
+
- [[-1, 10], 1, Concat, [1]] # cat head P5
|
50 |
+
- [-1, 1, Conv, [512, 3, 1]]
|
51 |
+
- [-1, 9, Conv, [512, 3, 1]] # 27
|
52 |
+
|
53 |
+
- [[19, 23, 27], 1, Detect, [nc]] # Detect(P3, P4, P5)
|
ultralytics/models/v8/yolov8-cls.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv8-cls image classification model. For Usage examples see https://docs.ultralytics.com/tasks/classify
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 1000 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
n: [0.33, 0.25, 1024]
|
9 |
+
s: [0.33, 0.50, 1024]
|
10 |
+
m: [0.67, 0.75, 1024]
|
11 |
+
l: [1.00, 1.00, 1024]
|
12 |
+
x: [1.00, 1.25, 1024]
|
13 |
+
|
14 |
+
# YOLOv8.0n backbone
|
15 |
+
backbone:
|
16 |
+
# [from, repeats, module, args]
|
17 |
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
18 |
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
19 |
+
- [-1, 3, C2f, [128, True]]
|
20 |
+
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
21 |
+
- [-1, 6, C2f, [256, True]]
|
22 |
+
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
23 |
+
- [-1, 6, C2f, [512, True]]
|
24 |
+
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
25 |
+
- [-1, 3, C2f, [1024, True]]
|
26 |
+
|
27 |
+
# YOLOv8.0n head
|
28 |
+
head:
|
29 |
+
- [-1, 1, Classify, [nc]] # Classify
|
ultralytics/models/v8/yolov8-p2.yaml
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv8 object detection model with P2-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
n: [0.33, 0.25, 1024]
|
9 |
+
s: [0.33, 0.50, 1024]
|
10 |
+
m: [0.67, 0.75, 768]
|
11 |
+
l: [1.00, 1.00, 512]
|
12 |
+
x: [1.00, 1.25, 512]
|
13 |
+
|
14 |
+
# YOLOv8.0 backbone
|
15 |
+
backbone:
|
16 |
+
# [from, repeats, module, args]
|
17 |
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
18 |
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
19 |
+
- [-1, 3, C2f, [128, True]]
|
20 |
+
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
21 |
+
- [-1, 6, C2f, [256, True]]
|
22 |
+
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
23 |
+
- [-1, 6, C2f, [512, True]]
|
24 |
+
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
25 |
+
- [-1, 3, C2f, [1024, True]]
|
26 |
+
- [-1, 1, SPPF, [1024, 5]] # 9
|
27 |
+
|
28 |
+
# YOLOv8.0-p2 head
|
29 |
+
head:
|
30 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
31 |
+
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
32 |
+
- [-1, 3, C2f, [512]] # 12
|
33 |
+
|
34 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
35 |
+
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
36 |
+
- [-1, 3, C2f, [256]] # 15 (P3/8-small)
|
37 |
+
|
38 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
39 |
+
- [[-1, 2], 1, Concat, [1]] # cat backbone P2
|
40 |
+
- [-1, 3, C2f, [128]] # 18 (P2/4-xsmall)
|
41 |
+
|
42 |
+
- [-1, 1, Conv, [128, 3, 2]]
|
43 |
+
- [[-1, 15], 1, Concat, [1]] # cat head P3
|
44 |
+
- [-1, 3, C2f, [256]] # 21 (P3/8-small)
|
45 |
+
|
46 |
+
- [-1, 1, Conv, [256, 3, 2]]
|
47 |
+
- [[-1, 12], 1, Concat, [1]] # cat head P4
|
48 |
+
- [-1, 3, C2f, [512]] # 24 (P4/16-medium)
|
49 |
+
|
50 |
+
- [-1, 1, Conv, [512, 3, 2]]
|
51 |
+
- [[-1, 9], 1, Concat, [1]] # cat head P5
|
52 |
+
- [-1, 3, C2f, [1024]] # 27 (P5/32-large)
|
53 |
+
|
54 |
+
- [[18, 21, 24, 27], 1, Detect, [nc]] # Detect(P2, P3, P4, P5)
|
ultralytics/models/v8/yolov8-p6.yaml
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv8 object detection model with P3-P6 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
n: [0.33, 0.25, 1024]
|
9 |
+
s: [0.33, 0.50, 1024]
|
10 |
+
m: [0.67, 0.75, 768]
|
11 |
+
l: [1.00, 1.00, 512]
|
12 |
+
x: [1.00, 1.25, 512]
|
13 |
+
|
14 |
+
# YOLOv8.0x6 backbone
|
15 |
+
backbone:
|
16 |
+
# [from, repeats, module, args]
|
17 |
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
18 |
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
19 |
+
- [-1, 3, C2f, [128, True]]
|
20 |
+
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
21 |
+
- [-1, 6, C2f, [256, True]]
|
22 |
+
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
23 |
+
- [-1, 6, C2f, [512, True]]
|
24 |
+
- [-1, 1, Conv, [768, 3, 2]] # 7-P5/32
|
25 |
+
- [-1, 3, C2f, [768, True]]
|
26 |
+
- [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64
|
27 |
+
- [-1, 3, C2f, [1024, True]]
|
28 |
+
- [-1, 1, SPPF, [1024, 5]] # 11
|
29 |
+
|
30 |
+
# YOLOv8.0x6 head
|
31 |
+
head:
|
32 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
33 |
+
- [[-1, 8], 1, Concat, [1]] # cat backbone P5
|
34 |
+
- [-1, 3, C2, [768, False]] # 14
|
35 |
+
|
36 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
37 |
+
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
38 |
+
- [-1, 3, C2, [512, False]] # 17
|
39 |
+
|
40 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
41 |
+
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
42 |
+
- [-1, 3, C2, [256, False]] # 20 (P3/8-small)
|
43 |
+
|
44 |
+
- [-1, 1, Conv, [256, 3, 2]]
|
45 |
+
- [[-1, 17], 1, Concat, [1]] # cat head P4
|
46 |
+
- [-1, 3, C2, [512, False]] # 23 (P4/16-medium)
|
47 |
+
|
48 |
+
- [-1, 1, Conv, [512, 3, 2]]
|
49 |
+
- [[-1, 14], 1, Concat, [1]] # cat head P5
|
50 |
+
- [-1, 3, C2, [768, False]] # 26 (P5/32-large)
|
51 |
+
|
52 |
+
- [-1, 1, Conv, [768, 3, 2]]
|
53 |
+
- [[-1, 11], 1, Concat, [1]] # cat head P6
|
54 |
+
- [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge)
|
55 |
+
|
56 |
+
- [[20, 23, 26, 29], 1, Detect, [nc]] # Detect(P3, P4, P5, P6)
|
ultralytics/models/v8/yolov8-pose-p6.yaml
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 1 # number of classes
|
6 |
+
kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
|
7 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n'
|
8 |
+
# [depth, width, max_channels]
|
9 |
+
n: [0.33, 0.25, 1024]
|
10 |
+
s: [0.33, 0.50, 1024]
|
11 |
+
m: [0.67, 0.75, 768]
|
12 |
+
l: [1.00, 1.00, 512]
|
13 |
+
x: [1.00, 1.25, 512]
|
14 |
+
|
15 |
+
# YOLOv8.0x6 backbone
|
16 |
+
backbone:
|
17 |
+
# [from, repeats, module, args]
|
18 |
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
19 |
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
20 |
+
- [-1, 3, C2f, [128, True]]
|
21 |
+
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
22 |
+
- [-1, 6, C2f, [256, True]]
|
23 |
+
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
24 |
+
- [-1, 6, C2f, [512, True]]
|
25 |
+
- [-1, 1, Conv, [768, 3, 2]] # 7-P5/32
|
26 |
+
- [-1, 3, C2f, [768, True]]
|
27 |
+
- [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64
|
28 |
+
- [-1, 3, C2f, [1024, True]]
|
29 |
+
- [-1, 1, SPPF, [1024, 5]] # 11
|
30 |
+
|
31 |
+
# YOLOv8.0x6 head
|
32 |
+
head:
|
33 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
34 |
+
- [[-1, 8], 1, Concat, [1]] # cat backbone P5
|
35 |
+
- [-1, 3, C2, [768, False]] # 14
|
36 |
+
|
37 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
38 |
+
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
39 |
+
- [-1, 3, C2, [512, False]] # 17
|
40 |
+
|
41 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
42 |
+
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
43 |
+
- [-1, 3, C2, [256, False]] # 20 (P3/8-small)
|
44 |
+
|
45 |
+
- [-1, 1, Conv, [256, 3, 2]]
|
46 |
+
- [[-1, 17], 1, Concat, [1]] # cat head P4
|
47 |
+
- [-1, 3, C2, [512, False]] # 23 (P4/16-medium)
|
48 |
+
|
49 |
+
- [-1, 1, Conv, [512, 3, 2]]
|
50 |
+
- [[-1, 14], 1, Concat, [1]] # cat head P5
|
51 |
+
- [-1, 3, C2, [768, False]] # 26 (P5/32-large)
|
52 |
+
|
53 |
+
- [-1, 1, Conv, [768, 3, 2]]
|
54 |
+
- [[-1, 11], 1, Concat, [1]] # cat head P6
|
55 |
+
- [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge)
|
56 |
+
|
57 |
+
- [[20, 23, 26, 29], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5, P6)
|
ultralytics/models/v8/yolov8-pose.yaml
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 1 # number of classes
|
6 |
+
kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
|
7 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n-pose.yaml' will call yolov8-pose.yaml with scale 'n'
|
8 |
+
# [depth, width, max_channels]
|
9 |
+
n: [0.33, 0.25, 1024]
|
10 |
+
s: [0.33, 0.50, 1024]
|
11 |
+
m: [0.67, 0.75, 768]
|
12 |
+
l: [1.00, 1.00, 512]
|
13 |
+
x: [1.00, 1.25, 512]
|
14 |
+
|
15 |
+
# YOLOv8.0n backbone
|
16 |
+
backbone:
|
17 |
+
# [from, repeats, module, args]
|
18 |
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
19 |
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
20 |
+
- [-1, 3, C2f, [128, True]]
|
21 |
+
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
22 |
+
- [-1, 6, C2f, [256, True]]
|
23 |
+
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
24 |
+
- [-1, 6, C2f, [512, True]]
|
25 |
+
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
26 |
+
- [-1, 3, C2f, [1024, True]]
|
27 |
+
- [-1, 1, SPPF, [1024, 5]] # 9
|
28 |
+
|
29 |
+
# YOLOv8.0n head
|
30 |
+
head:
|
31 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
32 |
+
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
33 |
+
- [-1, 3, C2f, [512]] # 12
|
34 |
+
|
35 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
36 |
+
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
37 |
+
- [-1, 3, C2f, [256]] # 15 (P3/8-small)
|
38 |
+
|
39 |
+
- [-1, 1, Conv, [256, 3, 2]]
|
40 |
+
- [[-1, 12], 1, Concat, [1]] # cat head P4
|
41 |
+
- [-1, 3, C2f, [512]] # 18 (P4/16-medium)
|
42 |
+
|
43 |
+
- [-1, 1, Conv, [512, 3, 2]]
|
44 |
+
- [[-1, 9], 1, Concat, [1]] # cat head P5
|
45 |
+
- [-1, 3, C2f, [1024]] # 21 (P5/32-large)
|
46 |
+
|
47 |
+
- [[15, 18, 21], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5)
|
ultralytics/models/v8/yolov8-seg.yaml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv8-seg instance segmentation model. For Usage examples see https://docs.ultralytics.com/tasks/segment
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n-seg.yaml' will call yolov8-seg.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
n: [0.33, 0.25, 1024]
|
9 |
+
s: [0.33, 0.50, 1024]
|
10 |
+
m: [0.67, 0.75, 768]
|
11 |
+
l: [1.00, 1.00, 512]
|
12 |
+
x: [1.00, 1.25, 512]
|
13 |
+
|
14 |
+
# YOLOv8.0n backbone
|
15 |
+
backbone:
|
16 |
+
# [from, repeats, module, args]
|
17 |
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
18 |
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
19 |
+
- [-1, 3, C2f, [128, True]]
|
20 |
+
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
21 |
+
- [-1, 6, C2f, [256, True]]
|
22 |
+
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
23 |
+
- [-1, 6, C2f, [512, True]]
|
24 |
+
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
25 |
+
- [-1, 3, C2f, [1024, True]]
|
26 |
+
- [-1, 1, SPPF, [1024, 5]] # 9
|
27 |
+
|
28 |
+
# YOLOv8.0n head
|
29 |
+
head:
|
30 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
31 |
+
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
32 |
+
- [-1, 3, C2f, [512]] # 12
|
33 |
+
|
34 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
35 |
+
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
36 |
+
- [-1, 3, C2f, [256]] # 15 (P3/8-small)
|
37 |
+
|
38 |
+
- [-1, 1, Conv, [256, 3, 2]]
|
39 |
+
- [[-1, 12], 1, Concat, [1]] # cat head P4
|
40 |
+
- [-1, 3, C2f, [512]] # 18 (P4/16-medium)
|
41 |
+
|
42 |
+
- [-1, 1, Conv, [512, 3, 2]]
|
43 |
+
- [[-1, 9], 1, Concat, [1]] # cat head P5
|
44 |
+
- [-1, 3, C2f, [1024]] # 21 (P5/32-large)
|
45 |
+
|
46 |
+
- [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Segment(P3, P4, P5)
|
ultralytics/models/v8/yolov8.yaml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
|
9 |
+
s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
|
10 |
+
m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
|
11 |
+
l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
|
12 |
+
x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
|
13 |
+
|
14 |
+
# YOLOv8.0n backbone
|
15 |
+
backbone:
|
16 |
+
# [from, repeats, module, args]
|
17 |
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
18 |
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
19 |
+
- [-1, 3, C2f, [128, True]]
|
20 |
+
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
21 |
+
- [-1, 6, C2f, [256, True]]
|
22 |
+
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
23 |
+
- [-1, 6, C2f, [512, True]]
|
24 |
+
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
25 |
+
- [-1, 3, C2f, [1024, True]]
|
26 |
+
- [-1, 1, SPPF, [1024, 5]] # 9
|
27 |
+
|
28 |
+
# YOLOv8.0n head
|
29 |
+
head:
|
30 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
31 |
+
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
32 |
+
- [-1, 3, C2f, [512]] # 12
|
33 |
+
|
34 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
35 |
+
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
36 |
+
- [-1, 3, C2f, [256]] # 15 (P3/8-small)
|
37 |
+
|
38 |
+
- [-1, 1, Conv, [256, 3, 2]]
|
39 |
+
- [[-1, 12], 1, Concat, [1]] # cat head P4
|
40 |
+
- [-1, 3, C2f, [512]] # 18 (P4/16-medium)
|
41 |
+
|
42 |
+
- [-1, 1, Conv, [512, 3, 2]]
|
43 |
+
- [[-1, 9], 1, Concat, [1]] # cat head P5
|
44 |
+
- [-1, 3, C2f, [1024]] # 21 (P5/32-large)
|
45 |
+
|
46 |
+
- [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)
|
ultralytics/nn/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
from .tasks import (BaseModel, ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
|
4 |
+
attempt_load_weights, guess_model_scale, guess_model_task, parse_model, torch_safe_load,
|
5 |
+
yaml_model_load)
|
6 |
+
|
7 |
+
__all__ = ('attempt_load_one_weight', 'attempt_load_weights', 'parse_model', 'yaml_model_load', 'guess_model_task',
|
8 |
+
'guess_model_scale', 'torch_safe_load', 'DetectionModel', 'SegmentationModel', 'ClassificationModel',
|
9 |
+
'BaseModel')
|
ultralytics/nn/autobackend.py
ADDED
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
import ast
|
4 |
+
import contextlib
|
5 |
+
import json
|
6 |
+
import platform
|
7 |
+
import zipfile
|
8 |
+
from collections import OrderedDict, namedtuple
|
9 |
+
from pathlib import Path
|
10 |
+
from urllib.parse import urlparse
|
11 |
+
|
12 |
+
import cv2
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
from PIL import Image
|
17 |
+
|
18 |
+
from ultralytics.yolo.utils import LINUX, LOGGER, ROOT, yaml_load
|
19 |
+
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_version, check_yaml
|
20 |
+
from ultralytics.yolo.utils.downloads import attempt_download_asset, is_url
|
21 |
+
from ultralytics.yolo.utils.ops import xywh2xyxy
|
22 |
+
|
23 |
+
|
24 |
+
def check_class_names(names):
|
25 |
+
"""Check class names. Map imagenet class codes to human-readable names if required. Convert lists to dicts."""
|
26 |
+
if isinstance(names, list): # names is a list
|
27 |
+
names = dict(enumerate(names)) # convert to dict
|
28 |
+
if isinstance(names, dict):
|
29 |
+
# Convert 1) string keys to int, i.e. '0' to 0, and non-string values to strings, i.e. True to 'True'
|
30 |
+
names = {int(k): str(v) for k, v in names.items()}
|
31 |
+
n = len(names)
|
32 |
+
if max(names.keys()) >= n:
|
33 |
+
raise KeyError(f'{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices '
|
34 |
+
f'{min(names.keys())}-{max(names.keys())} defined in your dataset YAML.')
|
35 |
+
if isinstance(names[0], str) and names[0].startswith('n0'): # imagenet class codes, i.e. 'n01440764'
|
36 |
+
map = yaml_load(ROOT / 'datasets/ImageNet.yaml')['map'] # human-readable names
|
37 |
+
names = {k: map[v] for k, v in names.items()}
|
38 |
+
return names
|
39 |
+
|
40 |
+
|
41 |
+
class AutoBackend(nn.Module):
|
42 |
+
|
43 |
+
def __init__(self,
|
44 |
+
weights='yolov8n.pt',
|
45 |
+
device=torch.device('cpu'),
|
46 |
+
dnn=False,
|
47 |
+
data=None,
|
48 |
+
fp16=False,
|
49 |
+
fuse=True,
|
50 |
+
verbose=True):
|
51 |
+
"""
|
52 |
+
MultiBackend class for python inference on various platforms using Ultralytics YOLO.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
weights (str): The path to the weights file. Default: 'yolov8n.pt'
|
56 |
+
device (torch.device): The device to run the model on.
|
57 |
+
dnn (bool): Use OpenCV's DNN module for inference if True, defaults to False.
|
58 |
+
data (str), (Path): Additional data.yaml file for class names, optional
|
59 |
+
fp16 (bool): If True, use half precision. Default: False
|
60 |
+
fuse (bool): Whether to fuse the model or not. Default: True
|
61 |
+
verbose (bool): Whether to run in verbose mode or not. Default: True
|
62 |
+
|
63 |
+
Supported formats and their naming conventions:
|
64 |
+
| Format | Suffix |
|
65 |
+
|-----------------------|------------------|
|
66 |
+
| PyTorch | *.pt |
|
67 |
+
| TorchScript | *.torchscript |
|
68 |
+
| ONNX Runtime | *.onnx |
|
69 |
+
| ONNX OpenCV DNN | *.onnx dnn=True |
|
70 |
+
| OpenVINO | *.xml |
|
71 |
+
| CoreML | *.mlmodel |
|
72 |
+
| TensorRT | *.engine |
|
73 |
+
| TensorFlow SavedModel | *_saved_model |
|
74 |
+
| TensorFlow GraphDef | *.pb |
|
75 |
+
| TensorFlow Lite | *.tflite |
|
76 |
+
| TensorFlow Edge TPU | *_edgetpu.tflite |
|
77 |
+
| PaddlePaddle | *_paddle_model |
|
78 |
+
"""
|
79 |
+
super().__init__()
|
80 |
+
w = str(weights[0] if isinstance(weights, list) else weights)
|
81 |
+
nn_module = isinstance(weights, torch.nn.Module)
|
82 |
+
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w)
|
83 |
+
fp16 &= pt or jit or onnx or engine or nn_module or triton # FP16
|
84 |
+
nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH)
|
85 |
+
stride = 32 # default stride
|
86 |
+
model, metadata = None, None
|
87 |
+
cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA
|
88 |
+
if not (pt or triton or nn_module):
|
89 |
+
w = attempt_download_asset(w) # download if not local
|
90 |
+
|
91 |
+
# NOTE: special case: in-memory pytorch model
|
92 |
+
if nn_module:
|
93 |
+
model = weights.to(device)
|
94 |
+
model = model.fuse(verbose=verbose) if fuse else model
|
95 |
+
if hasattr(model, 'kpt_shape'):
|
96 |
+
kpt_shape = model.kpt_shape # pose-only
|
97 |
+
stride = max(int(model.stride.max()), 32) # model stride
|
98 |
+
names = model.module.names if hasattr(model, 'module') else model.names # get class names
|
99 |
+
model.half() if fp16 else model.float()
|
100 |
+
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
|
101 |
+
pt = True
|
102 |
+
elif pt: # PyTorch
|
103 |
+
from ultralytics.nn.tasks import attempt_load_weights
|
104 |
+
model = attempt_load_weights(weights if isinstance(weights, list) else w,
|
105 |
+
device=device,
|
106 |
+
inplace=True,
|
107 |
+
fuse=fuse)
|
108 |
+
if hasattr(model, 'kpt_shape'):
|
109 |
+
kpt_shape = model.kpt_shape # pose-only
|
110 |
+
stride = max(int(model.stride.max()), 32) # model stride
|
111 |
+
names = model.module.names if hasattr(model, 'module') else model.names # get class names
|
112 |
+
model.half() if fp16 else model.float()
|
113 |
+
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
|
114 |
+
elif jit: # TorchScript
|
115 |
+
LOGGER.info(f'Loading {w} for TorchScript inference...')
|
116 |
+
extra_files = {'config.txt': ''} # model metadata
|
117 |
+
model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
|
118 |
+
model.half() if fp16 else model.float()
|
119 |
+
if extra_files['config.txt']: # load metadata dict
|
120 |
+
metadata = json.loads(extra_files['config.txt'], object_hook=lambda x: dict(x.items()))
|
121 |
+
elif dnn: # ONNX OpenCV DNN
|
122 |
+
LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
|
123 |
+
check_requirements('opencv-python>=4.5.4')
|
124 |
+
net = cv2.dnn.readNetFromONNX(w)
|
125 |
+
elif onnx: # ONNX Runtime
|
126 |
+
LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
|
127 |
+
check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
|
128 |
+
import onnxruntime
|
129 |
+
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
|
130 |
+
session = onnxruntime.InferenceSession(w, providers=providers)
|
131 |
+
output_names = [x.name for x in session.get_outputs()]
|
132 |
+
metadata = session.get_modelmeta().custom_metadata_map # metadata
|
133 |
+
elif xml: # OpenVINO
|
134 |
+
LOGGER.info(f'Loading {w} for OpenVINO inference...')
|
135 |
+
check_requirements('openvino') # requires openvino-dev: https://pypi.org/project/openvino-dev/
|
136 |
+
from openvino.runtime import Core, Layout, get_batch # noqa
|
137 |
+
ie = Core()
|
138 |
+
w = Path(w)
|
139 |
+
if not w.is_file(): # if not *.xml
|
140 |
+
w = next(w.glob('*.xml')) # get *.xml file from *_openvino_model dir
|
141 |
+
network = ie.read_model(model=str(w), weights=w.with_suffix('.bin'))
|
142 |
+
if network.get_parameters()[0].get_layout().empty:
|
143 |
+
network.get_parameters()[0].set_layout(Layout('NCHW'))
|
144 |
+
batch_dim = get_batch(network)
|
145 |
+
if batch_dim.is_static:
|
146 |
+
batch_size = batch_dim.get_length()
|
147 |
+
executable_network = ie.compile_model(network, device_name='CPU') # device_name="MYRIAD" for NCS2
|
148 |
+
metadata = w.parent / 'metadata.yaml'
|
149 |
+
elif engine: # TensorRT
|
150 |
+
LOGGER.info(f'Loading {w} for TensorRT inference...')
|
151 |
+
try:
|
152 |
+
import tensorrt as trt # noqa https://developer.nvidia.com/nvidia-tensorrt-download
|
153 |
+
except ImportError:
|
154 |
+
if LINUX:
|
155 |
+
check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
|
156 |
+
import tensorrt as trt # noqa
|
157 |
+
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
|
158 |
+
if device.type == 'cpu':
|
159 |
+
device = torch.device('cuda:0')
|
160 |
+
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
|
161 |
+
logger = trt.Logger(trt.Logger.INFO)
|
162 |
+
# Read file
|
163 |
+
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
|
164 |
+
meta_len = int.from_bytes(f.read(4), byteorder='little') # read metadata length
|
165 |
+
metadata = json.loads(f.read(meta_len).decode('utf-8')) # read metadata
|
166 |
+
model = runtime.deserialize_cuda_engine(f.read()) # read engine
|
167 |
+
context = model.create_execution_context()
|
168 |
+
bindings = OrderedDict()
|
169 |
+
output_names = []
|
170 |
+
fp16 = False # default updated below
|
171 |
+
dynamic = False
|
172 |
+
for i in range(model.num_bindings):
|
173 |
+
name = model.get_binding_name(i)
|
174 |
+
dtype = trt.nptype(model.get_binding_dtype(i))
|
175 |
+
if model.binding_is_input(i):
|
176 |
+
if -1 in tuple(model.get_binding_shape(i)): # dynamic
|
177 |
+
dynamic = True
|
178 |
+
context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
|
179 |
+
if dtype == np.float16:
|
180 |
+
fp16 = True
|
181 |
+
else: # output
|
182 |
+
output_names.append(name)
|
183 |
+
shape = tuple(context.get_binding_shape(i))
|
184 |
+
im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
|
185 |
+
bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
|
186 |
+
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
|
187 |
+
batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size
|
188 |
+
elif coreml: # CoreML
|
189 |
+
LOGGER.info(f'Loading {w} for CoreML inference...')
|
190 |
+
import coremltools as ct
|
191 |
+
model = ct.models.MLModel(w)
|
192 |
+
metadata = dict(model.user_defined_metadata)
|
193 |
+
elif saved_model: # TF SavedModel
|
194 |
+
LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
|
195 |
+
import tensorflow as tf
|
196 |
+
keras = False # assume TF1 saved_model
|
197 |
+
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
|
198 |
+
metadata = Path(w) / 'metadata.yaml'
|
199 |
+
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
|
200 |
+
LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
|
201 |
+
import tensorflow as tf
|
202 |
+
|
203 |
+
from ultralytics.yolo.engine.exporter import gd_outputs
|
204 |
+
|
205 |
+
def wrap_frozen_graph(gd, inputs, outputs):
|
206 |
+
"""Wrap frozen graphs for deployment."""
|
207 |
+
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), []) # wrapped
|
208 |
+
ge = x.graph.as_graph_element
|
209 |
+
return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
|
210 |
+
|
211 |
+
gd = tf.Graph().as_graph_def() # TF GraphDef
|
212 |
+
with open(w, 'rb') as f:
|
213 |
+
gd.ParseFromString(f.read())
|
214 |
+
frozen_func = wrap_frozen_graph(gd, inputs='x:0', outputs=gd_outputs(gd))
|
215 |
+
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
|
216 |
+
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
|
217 |
+
from tflite_runtime.interpreter import Interpreter, load_delegate
|
218 |
+
except ImportError:
|
219 |
+
import tensorflow as tf
|
220 |
+
Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
|
221 |
+
if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
|
222 |
+
LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
|
223 |
+
delegate = {
|
224 |
+
'Linux': 'libedgetpu.so.1',
|
225 |
+
'Darwin': 'libedgetpu.1.dylib',
|
226 |
+
'Windows': 'edgetpu.dll'}[platform.system()]
|
227 |
+
interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
|
228 |
+
else: # TFLite
|
229 |
+
LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
|
230 |
+
interpreter = Interpreter(model_path=w) # load TFLite model
|
231 |
+
interpreter.allocate_tensors() # allocate
|
232 |
+
input_details = interpreter.get_input_details() # inputs
|
233 |
+
output_details = interpreter.get_output_details() # outputs
|
234 |
+
# Load metadata
|
235 |
+
with contextlib.suppress(zipfile.BadZipFile):
|
236 |
+
with zipfile.ZipFile(w, 'r') as model:
|
237 |
+
meta_file = model.namelist()[0]
|
238 |
+
metadata = ast.literal_eval(model.read(meta_file).decode('utf-8'))
|
239 |
+
elif tfjs: # TF.js
|
240 |
+
raise NotImplementedError('YOLOv8 TF.js inference is not supported')
|
241 |
+
elif paddle: # PaddlePaddle
|
242 |
+
LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
|
243 |
+
check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
|
244 |
+
import paddle.inference as pdi # noqa
|
245 |
+
w = Path(w)
|
246 |
+
if not w.is_file(): # if not *.pdmodel
|
247 |
+
w = next(w.rglob('*.pdmodel')) # get *.pdmodel file from *_paddle_model dir
|
248 |
+
config = pdi.Config(str(w), str(w.with_suffix('.pdiparams')))
|
249 |
+
if cuda:
|
250 |
+
config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
|
251 |
+
predictor = pdi.create_predictor(config)
|
252 |
+
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
|
253 |
+
output_names = predictor.get_output_names()
|
254 |
+
metadata = w.parents[1] / 'metadata.yaml'
|
255 |
+
elif triton: # NVIDIA Triton Inference Server
|
256 |
+
LOGGER.info('Triton Inference Server not supported...')
|
257 |
+
'''
|
258 |
+
TODO:
|
259 |
+
check_requirements('tritonclient[all]')
|
260 |
+
from utils.triton import TritonRemoteModel
|
261 |
+
model = TritonRemoteModel(url=w)
|
262 |
+
nhwc = model.runtime.startswith("tensorflow")
|
263 |
+
'''
|
264 |
+
else:
|
265 |
+
from ultralytics.yolo.engine.exporter import export_formats
|
266 |
+
raise TypeError(f"model='{w}' is not a supported model format. "
|
267 |
+
'See https://docs.ultralytics.com/modes/predict for help.'
|
268 |
+
f'\n\n{export_formats()}')
|
269 |
+
|
270 |
+
# Load external metadata YAML
|
271 |
+
if isinstance(metadata, (str, Path)) and Path(metadata).exists():
|
272 |
+
metadata = yaml_load(metadata)
|
273 |
+
if metadata:
|
274 |
+
for k, v in metadata.items():
|
275 |
+
if k in ('stride', 'batch'):
|
276 |
+
metadata[k] = int(v)
|
277 |
+
elif k in ('imgsz', 'names', 'kpt_shape') and isinstance(v, str):
|
278 |
+
metadata[k] = eval(v)
|
279 |
+
stride = metadata['stride']
|
280 |
+
task = metadata['task']
|
281 |
+
batch = metadata['batch']
|
282 |
+
imgsz = metadata['imgsz']
|
283 |
+
names = metadata['names']
|
284 |
+
kpt_shape = metadata.get('kpt_shape')
|
285 |
+
elif not (pt or triton or nn_module):
|
286 |
+
LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'")
|
287 |
+
|
288 |
+
# Check names
|
289 |
+
if 'names' not in locals(): # names missing
|
290 |
+
names = self._apply_default_class_names(data)
|
291 |
+
names = check_class_names(names)
|
292 |
+
|
293 |
+
self.__dict__.update(locals()) # assign all variables to self
|
294 |
+
|
295 |
+
def forward(self, im, augment=False, visualize=False):
|
296 |
+
"""
|
297 |
+
Runs inference on the YOLOv8 MultiBackend model.
|
298 |
+
|
299 |
+
Args:
|
300 |
+
im (torch.Tensor): The image tensor to perform inference on.
|
301 |
+
augment (bool): whether to perform data augmentation during inference, defaults to False
|
302 |
+
visualize (bool): whether to visualize the output predictions, defaults to False
|
303 |
+
|
304 |
+
Returns:
|
305 |
+
(tuple): Tuple containing the raw output tensor, and processed output for visualization (if visualize=True)
|
306 |
+
"""
|
307 |
+
b, ch, h, w = im.shape # batch, channel, height, width
|
308 |
+
if self.fp16 and im.dtype != torch.float16:
|
309 |
+
im = im.half() # to FP16
|
310 |
+
if self.nhwc:
|
311 |
+
im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)
|
312 |
+
|
313 |
+
if self.pt or self.nn_module: # PyTorch
|
314 |
+
y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
|
315 |
+
elif self.jit: # TorchScript
|
316 |
+
y = self.model(im)
|
317 |
+
elif self.dnn: # ONNX OpenCV DNN
|
318 |
+
im = im.cpu().numpy() # torch to numpy
|
319 |
+
self.net.setInput(im)
|
320 |
+
y = self.net.forward()
|
321 |
+
elif self.onnx: # ONNX Runtime
|
322 |
+
im = im.cpu().numpy() # torch to numpy
|
323 |
+
y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
|
324 |
+
elif self.xml: # OpenVINO
|
325 |
+
im = im.cpu().numpy() # FP32
|
326 |
+
y = list(self.executable_network([im]).values())
|
327 |
+
elif self.engine: # TensorRT
|
328 |
+
if self.dynamic and im.shape != self.bindings['images'].shape:
|
329 |
+
i = self.model.get_binding_index('images')
|
330 |
+
self.context.set_binding_shape(i, im.shape) # reshape if dynamic
|
331 |
+
self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
|
332 |
+
for name in self.output_names:
|
333 |
+
i = self.model.get_binding_index(name)
|
334 |
+
self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
|
335 |
+
s = self.bindings['images'].shape
|
336 |
+
assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
|
337 |
+
self.binding_addrs['images'] = int(im.data_ptr())
|
338 |
+
self.context.execute_v2(list(self.binding_addrs.values()))
|
339 |
+
y = [self.bindings[x].data for x in sorted(self.output_names)]
|
340 |
+
elif self.coreml: # CoreML
|
341 |
+
im = im[0].cpu().numpy()
|
342 |
+
im_pil = Image.fromarray((im * 255).astype('uint8'))
|
343 |
+
# im = im.resize((192, 320), Image.ANTIALIAS)
|
344 |
+
y = self.model.predict({'image': im_pil}) # coordinates are xywh normalized
|
345 |
+
if 'confidence' in y:
|
346 |
+
box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
|
347 |
+
conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
|
348 |
+
y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
|
349 |
+
elif len(y) == 1: # classification model
|
350 |
+
y = list(y.values())
|
351 |
+
elif len(y) == 2: # segmentation model
|
352 |
+
y = list(reversed(y.values())) # reversed for segmentation models (pred, proto)
|
353 |
+
elif self.paddle: # PaddlePaddle
|
354 |
+
im = im.cpu().numpy().astype(np.float32)
|
355 |
+
self.input_handle.copy_from_cpu(im)
|
356 |
+
self.predictor.run()
|
357 |
+
y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
|
358 |
+
elif self.triton: # NVIDIA Triton Inference Server
|
359 |
+
y = self.model(im)
|
360 |
+
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
|
361 |
+
im = im.cpu().numpy()
|
362 |
+
if self.saved_model: # SavedModel
|
363 |
+
y = self.model(im, training=False) if self.keras else self.model(im)
|
364 |
+
if not isinstance(y, list):
|
365 |
+
y = [y]
|
366 |
+
elif self.pb: # GraphDef
|
367 |
+
y = self.frozen_func(x=self.tf.constant(im))
|
368 |
+
if len(y) == 2 and len(self.names) == 999: # segments and names not defined
|
369 |
+
ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0) # index of protos, boxes
|
370 |
+
nc = y[ib].shape[1] - y[ip].shape[3] - 4 # y = (1, 160, 160, 32), (1, 116, 8400)
|
371 |
+
self.names = {i: f'class{i}' for i in range(nc)}
|
372 |
+
else: # Lite or Edge TPU
|
373 |
+
input = self.input_details[0]
|
374 |
+
int8 = input['dtype'] == np.int8 # is TFLite quantized int8 model
|
375 |
+
if int8:
|
376 |
+
scale, zero_point = input['quantization']
|
377 |
+
im = (im / scale + zero_point).astype(np.int8) # de-scale
|
378 |
+
self.interpreter.set_tensor(input['index'], im)
|
379 |
+
self.interpreter.invoke()
|
380 |
+
y = []
|
381 |
+
for output in self.output_details:
|
382 |
+
x = self.interpreter.get_tensor(output['index'])
|
383 |
+
if int8:
|
384 |
+
scale, zero_point = output['quantization']
|
385 |
+
x = (x.astype(np.float32) - zero_point) * scale # re-scale
|
386 |
+
y.append(x)
|
387 |
+
# TF segment fixes: export is reversed vs ONNX export and protos are transposed
|
388 |
+
if len(y) == 2: # segment with (det, proto) output order reversed
|
389 |
+
if len(y[1].shape) != 4:
|
390 |
+
y = list(reversed(y)) # should be y = (1, 116, 8400), (1, 160, 160, 32)
|
391 |
+
y[1] = np.transpose(y[1], (0, 3, 1, 2)) # should be y = (1, 116, 8400), (1, 32, 160, 160)
|
392 |
+
y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
|
393 |
+
# y[0][..., :4] *= [w, h, w, h] # xywh normalized to pixels
|
394 |
+
|
395 |
+
# for x in y:
|
396 |
+
# print(type(x), len(x)) if isinstance(x, (list, tuple)) else print(type(x), x.shape) # debug shapes
|
397 |
+
if isinstance(y, (list, tuple)):
|
398 |
+
return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
|
399 |
+
else:
|
400 |
+
return self.from_numpy(y)
|
401 |
+
|
402 |
+
def from_numpy(self, x):
|
403 |
+
"""
|
404 |
+
Convert a numpy array to a tensor.
|
405 |
+
|
406 |
+
Args:
|
407 |
+
x (np.ndarray): The array to be converted.
|
408 |
+
|
409 |
+
Returns:
|
410 |
+
(torch.Tensor): The converted tensor
|
411 |
+
"""
|
412 |
+
return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x
|
413 |
+
|
414 |
+
def warmup(self, imgsz=(1, 3, 640, 640)):
|
415 |
+
"""
|
416 |
+
Warm up the model by running one forward pass with a dummy input.
|
417 |
+
|
418 |
+
Args:
|
419 |
+
imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width)
|
420 |
+
|
421 |
+
Returns:
|
422 |
+
(None): This method runs the forward pass and don't return any value
|
423 |
+
"""
|
424 |
+
warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
|
425 |
+
if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
|
426 |
+
im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
|
427 |
+
for _ in range(2 if self.jit else 1): #
|
428 |
+
self.forward(im) # warmup
|
429 |
+
|
430 |
+
@staticmethod
|
431 |
+
def _apply_default_class_names(data):
|
432 |
+
"""Applies default class names to an input YAML file or returns numerical class names."""
|
433 |
+
with contextlib.suppress(Exception):
|
434 |
+
return yaml_load(check_yaml(data))['names']
|
435 |
+
return {i: f'class{i}' for i in range(999)} # return default if above errors
|
436 |
+
|
437 |
+
@staticmethod
|
438 |
+
def _model_type(p='path/to/model.pt'):
|
439 |
+
"""
|
440 |
+
This function takes a path to a model file and returns the model type
|
441 |
+
|
442 |
+
Args:
|
443 |
+
p: path to the model file. Defaults to path/to/model.pt
|
444 |
+
"""
|
445 |
+
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
|
446 |
+
# types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
|
447 |
+
from ultralytics.yolo.engine.exporter import export_formats
|
448 |
+
sf = list(export_formats().Suffix) # export suffixes
|
449 |
+
if not is_url(p, check=False) and not isinstance(p, str):
|
450 |
+
check_suffix(p, sf) # checks
|
451 |
+
url = urlparse(p) # if url may be Triton inference server
|
452 |
+
types = [s in Path(p).name for s in sf]
|
453 |
+
types[8] &= not types[9] # tflite &= not edgetpu
|
454 |
+
triton = not any(types) and all([any(s in url.scheme for s in ['http', 'grpc']), url.netloc])
|
455 |
+
return types + [triton]
|
ultralytics/nn/autoshape.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
"""
|
3 |
+
Common modules
|
4 |
+
"""
|
5 |
+
|
6 |
+
from copy import copy
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
+
import cv2
|
10 |
+
import numpy as np
|
11 |
+
import requests
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
from PIL import Image, ImageOps
|
15 |
+
from torch.cuda import amp
|
16 |
+
|
17 |
+
from ultralytics.nn.autobackend import AutoBackend
|
18 |
+
from ultralytics.yolo.data.augment import LetterBox
|
19 |
+
from ultralytics.yolo.utils import LOGGER, colorstr
|
20 |
+
from ultralytics.yolo.utils.files import increment_path
|
21 |
+
from ultralytics.yolo.utils.ops import Profile, make_divisible, non_max_suppression, scale_boxes, xyxy2xywh
|
22 |
+
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
|
23 |
+
from ultralytics.yolo.utils.torch_utils import copy_attr, smart_inference_mode
|
24 |
+
|
25 |
+
|
26 |
+
class AutoShape(nn.Module):
|
27 |
+
"""YOLOv8 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS."""
|
28 |
+
conf = 0.25 # NMS confidence threshold
|
29 |
+
iou = 0.45 # NMS IoU threshold
|
30 |
+
agnostic = False # NMS class-agnostic
|
31 |
+
multi_label = False # NMS multiple labels per box
|
32 |
+
classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
|
33 |
+
max_det = 1000 # maximum number of detections per image
|
34 |
+
amp = False # Automatic Mixed Precision (AMP) inference
|
35 |
+
|
36 |
+
def __init__(self, model, verbose=True):
|
37 |
+
"""Initializes object and copies attributes from model object."""
|
38 |
+
super().__init__()
|
39 |
+
if verbose:
|
40 |
+
LOGGER.info('Adding AutoShape... ')
|
41 |
+
copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
|
42 |
+
self.dmb = isinstance(model, AutoBackend) # DetectMultiBackend() instance
|
43 |
+
self.pt = not self.dmb or model.pt # PyTorch model
|
44 |
+
self.model = model.eval()
|
45 |
+
if self.pt:
|
46 |
+
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
|
47 |
+
m.inplace = False # Detect.inplace=False for safe multithread inference
|
48 |
+
m.export = True # do not output loss values
|
49 |
+
|
50 |
+
def _apply(self, fn):
|
51 |
+
"""Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers."""
|
52 |
+
self = super()._apply(fn)
|
53 |
+
if self.pt:
|
54 |
+
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
|
55 |
+
m.stride = fn(m.stride)
|
56 |
+
m.grid = list(map(fn, m.grid))
|
57 |
+
if isinstance(m.anchor_grid, list):
|
58 |
+
m.anchor_grid = list(map(fn, m.anchor_grid))
|
59 |
+
return self
|
60 |
+
|
61 |
+
@smart_inference_mode()
|
62 |
+
def forward(self, ims, size=640, augment=False, profile=False):
|
63 |
+
"""Inference from various sources. For size(height=640, width=1280), RGB images example inputs are:."""
|
64 |
+
# file: ims = 'data/images/zidane.jpg' # str or PosixPath
|
65 |
+
# URI: = 'https://ultralytics.com/images/zidane.jpg'
|
66 |
+
# OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
|
67 |
+
# PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
|
68 |
+
# numpy: = np.zeros((640,1280,3)) # HWC
|
69 |
+
# torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
|
70 |
+
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
|
71 |
+
|
72 |
+
dt = (Profile(), Profile(), Profile())
|
73 |
+
with dt[0]:
|
74 |
+
if isinstance(size, int): # expand
|
75 |
+
size = (size, size)
|
76 |
+
p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param
|
77 |
+
autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
|
78 |
+
if isinstance(ims, torch.Tensor): # torch
|
79 |
+
with amp.autocast(autocast):
|
80 |
+
return self.model(ims.to(p.device).type_as(p), augment=augment) # inference
|
81 |
+
|
82 |
+
# Preprocess
|
83 |
+
n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images
|
84 |
+
shape0, shape1, files = [], [], [] # image and inference shapes, filenames
|
85 |
+
for i, im in enumerate(ims):
|
86 |
+
f = f'image{i}' # filename
|
87 |
+
if isinstance(im, (str, Path)): # filename or uri
|
88 |
+
im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
|
89 |
+
im = np.asarray(ImageOps.exif_transpose(im))
|
90 |
+
elif isinstance(im, Image.Image): # PIL Image
|
91 |
+
im, f = np.asarray(ImageOps.exif_transpose(im)), getattr(im, 'filename', f) or f
|
92 |
+
files.append(Path(f).with_suffix('.jpg').name)
|
93 |
+
if im.shape[0] < 5: # image in CHW
|
94 |
+
im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
|
95 |
+
im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input
|
96 |
+
s = im.shape[:2] # HWC
|
97 |
+
shape0.append(s) # image shape
|
98 |
+
g = max(size) / max(s) # gain
|
99 |
+
shape1.append([y * g for y in s])
|
100 |
+
ims[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
|
101 |
+
shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] if self.pt else size # inf shape
|
102 |
+
x = [LetterBox(shape1, auto=False)(image=im)['img'] for im in ims] # pad
|
103 |
+
x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
|
104 |
+
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
|
105 |
+
|
106 |
+
with amp.autocast(autocast):
|
107 |
+
# Inference
|
108 |
+
with dt[1]:
|
109 |
+
y = self.model(x, augment=augment) # forward
|
110 |
+
|
111 |
+
# Postprocess
|
112 |
+
with dt[2]:
|
113 |
+
y = non_max_suppression(y if self.dmb else y[0],
|
114 |
+
self.conf,
|
115 |
+
self.iou,
|
116 |
+
self.classes,
|
117 |
+
self.agnostic,
|
118 |
+
self.multi_label,
|
119 |
+
max_det=self.max_det) # NMS
|
120 |
+
for i in range(n):
|
121 |
+
scale_boxes(shape1, y[i][:, :4], shape0[i])
|
122 |
+
|
123 |
+
return Detections(ims, y, files, dt, self.names, x.shape)
|
124 |
+
|
125 |
+
|
126 |
+
class Detections:
|
127 |
+
# YOLOv8 detections class for inference results
|
128 |
+
def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
|
129 |
+
"""Initialize object attributes for YOLO detection results."""
|
130 |
+
super().__init__()
|
131 |
+
d = pred[0].device # device
|
132 |
+
gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims] # normalizations
|
133 |
+
self.ims = ims # list of images as numpy arrays
|
134 |
+
self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
|
135 |
+
self.names = names # class names
|
136 |
+
self.files = files # image filenames
|
137 |
+
self.times = times # profiling times
|
138 |
+
self.xyxy = pred # xyxy pixels
|
139 |
+
self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
|
140 |
+
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
|
141 |
+
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
|
142 |
+
self.n = len(self.pred) # number of images (batch size)
|
143 |
+
self.t = tuple(x.t / self.n * 1E3 for x in times) # timestamps (ms)
|
144 |
+
self.s = tuple(shape) # inference BCHW shape
|
145 |
+
|
146 |
+
def _run(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
|
147 |
+
"""Return performance metrics and optionally cropped/save images or results."""
|
148 |
+
s, crops = '', []
|
149 |
+
for i, (im, pred) in enumerate(zip(self.ims, self.pred)):
|
150 |
+
s += f'\nimage {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
|
151 |
+
if pred.shape[0]:
|
152 |
+
for c in pred[:, -1].unique():
|
153 |
+
n = (pred[:, -1] == c).sum() # detections per class
|
154 |
+
s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string
|
155 |
+
s = s.rstrip(', ')
|
156 |
+
if show or save or render or crop:
|
157 |
+
annotator = Annotator(im, example=str(self.names))
|
158 |
+
for *box, conf, cls in reversed(pred): # xyxy, confidence, class
|
159 |
+
label = f'{self.names[int(cls)]} {conf:.2f}'
|
160 |
+
if crop:
|
161 |
+
file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
|
162 |
+
crops.append({
|
163 |
+
'box': box,
|
164 |
+
'conf': conf,
|
165 |
+
'cls': cls,
|
166 |
+
'label': label,
|
167 |
+
'im': save_one_box(box, im, file=file, save=save)})
|
168 |
+
else: # all others
|
169 |
+
annotator.box_label(box, label if labels else '', color=colors(cls))
|
170 |
+
im = annotator.im
|
171 |
+
else:
|
172 |
+
s += '(no detections)'
|
173 |
+
|
174 |
+
im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
|
175 |
+
if show:
|
176 |
+
im.show(self.files[i]) # show
|
177 |
+
if save:
|
178 |
+
f = self.files[i]
|
179 |
+
im.save(save_dir / f) # save
|
180 |
+
if i == self.n - 1:
|
181 |
+
LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
|
182 |
+
if render:
|
183 |
+
self.ims[i] = np.asarray(im)
|
184 |
+
if pprint:
|
185 |
+
s = s.lstrip('\n')
|
186 |
+
return f'{s}\nSpeed: %.1fms preprocess, %.1fms inference, %.1fms NMS per image at shape {self.s}' % self.t
|
187 |
+
if crop:
|
188 |
+
if save:
|
189 |
+
LOGGER.info(f'Saved results to {save_dir}\n')
|
190 |
+
return crops
|
191 |
+
|
192 |
+
def show(self, labels=True):
|
193 |
+
"""Displays YOLO results with detected bounding boxes."""
|
194 |
+
self._run(show=True, labels=labels) # show results
|
195 |
+
|
196 |
+
def save(self, labels=True, save_dir='runs/detect/exp', exist_ok=False):
|
197 |
+
"""Save detection results with optional labels to specified directory."""
|
198 |
+
save_dir = increment_path(save_dir, exist_ok, mkdir=True) # increment save_dir
|
199 |
+
self._run(save=True, labels=labels, save_dir=save_dir) # save results
|
200 |
+
|
201 |
+
def crop(self, save=True, save_dir='runs/detect/exp', exist_ok=False):
|
202 |
+
"""Crops images into detections and saves them if 'save' is True."""
|
203 |
+
save_dir = increment_path(save_dir, exist_ok, mkdir=True) if save else None
|
204 |
+
return self._run(crop=True, save=save, save_dir=save_dir) # crop results
|
205 |
+
|
206 |
+
def render(self, labels=True):
|
207 |
+
"""Renders detected objects and returns images."""
|
208 |
+
self._run(render=True, labels=labels) # render results
|
209 |
+
return self.ims
|
210 |
+
|
211 |
+
def pandas(self):
|
212 |
+
"""Return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])."""
|
213 |
+
import pandas
|
214 |
+
new = copy(self) # return copy
|
215 |
+
ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
|
216 |
+
cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
|
217 |
+
for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
|
218 |
+
a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
|
219 |
+
setattr(new, k, [pandas.DataFrame(x, columns=c) for x in a])
|
220 |
+
return new
|
221 |
+
|
222 |
+
def tolist(self):
|
223 |
+
"""Return a list of Detections objects, i.e. 'for result in results.tolist():'."""
|
224 |
+
r = range(self.n) # iterable
|
225 |
+
x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
|
226 |
+
# for d in x:
|
227 |
+
# for k in ['ims', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
|
228 |
+
# setattr(d, k, getattr(d, k)[0]) # pop out of list
|
229 |
+
return x
|
230 |
+
|
231 |
+
def print(self):
|
232 |
+
"""Print the results of the `self._run()` function."""
|
233 |
+
LOGGER.info(self.__str__())
|
234 |
+
|
235 |
+
def __len__(self): # override len(results)
|
236 |
+
return self.n
|
237 |
+
|
238 |
+
def __str__(self): # override print(results)
|
239 |
+
return self._run(pprint=True) # print results
|
240 |
+
|
241 |
+
def __repr__(self):
|
242 |
+
"""Returns a printable representation of the object."""
|
243 |
+
return f'YOLOv8 {self.__class__} instance\n' + self.__str__()
|
ultralytics/nn/modules/__init__.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
"""
|
3 |
+
Ultralytics modules. Visualize with:
|
4 |
+
|
5 |
+
from ultralytics.nn.modules import *
|
6 |
+
import torch
|
7 |
+
import os
|
8 |
+
|
9 |
+
x = torch.ones(1, 128, 40, 40)
|
10 |
+
m = Conv(128, 128)
|
11 |
+
f = f'{m._get_name()}.onnx'
|
12 |
+
torch.onnx.export(m, x, f)
|
13 |
+
os.system(f'onnxsim {f} {f} && open {f}')
|
14 |
+
"""
|
15 |
+
|
16 |
+
from .block import (C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, GhostBottleneck,
|
17 |
+
HGBlock, HGStem, Proto, RepC3)
|
18 |
+
from .conv import (CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus,
|
19 |
+
GhostConv, LightConv, RepConv, SpatialAttention)
|
20 |
+
from .head import Classify, Detect, Pose, RTDETRDecoder, Segment
|
21 |
+
from .transformer import (AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d,
|
22 |
+
MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer)
|
23 |
+
|
24 |
+
__all__ = ('Conv', 'Conv2', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus',
|
25 |
+
'GhostConv', 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'TransformerLayer',
|
26 |
+
'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3',
|
27 |
+
'C2f', 'C3x', 'C3TR', 'C3Ghost', 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect',
|
28 |
+
'Segment', 'Pose', 'Classify', 'TransformerEncoderLayer', 'RepC3', 'RTDETRDecoder', 'AIFI',
|
29 |
+
'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP')
|
ultralytics/nn/modules/block.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
"""
|
3 |
+
Block modules
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from .conv import Conv, DWConv, GhostConv, LightConv, RepConv
|
11 |
+
from .transformer import TransformerBlock
|
12 |
+
|
13 |
+
__all__ = ('DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', 'C2f', 'C3x', 'C3TR', 'C3Ghost',
|
14 |
+
'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'RepC3')
|
15 |
+
|
16 |
+
|
17 |
+
class DFL(nn.Module):
|
18 |
+
"""
|
19 |
+
Integral module of Distribution Focal Loss (DFL).
|
20 |
+
Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, c1=16):
|
24 |
+
"""Initialize a convolutional layer with a given number of input channels."""
|
25 |
+
super().__init__()
|
26 |
+
self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
|
27 |
+
x = torch.arange(c1, dtype=torch.float)
|
28 |
+
self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
|
29 |
+
self.c1 = c1
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
"""Applies a transformer layer on input tensor 'x' and returns a tensor."""
|
33 |
+
b, c, a = x.shape # batch, channels, anchors
|
34 |
+
return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
|
35 |
+
# return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)
|
36 |
+
|
37 |
+
|
38 |
+
class Proto(nn.Module):
|
39 |
+
"""YOLOv8 mask Proto module for segmentation models."""
|
40 |
+
|
41 |
+
def __init__(self, c1, c_=256, c2=32): # ch_in, number of protos, number of masks
|
42 |
+
super().__init__()
|
43 |
+
self.cv1 = Conv(c1, c_, k=3)
|
44 |
+
self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True) # nn.Upsample(scale_factor=2, mode='nearest')
|
45 |
+
self.cv2 = Conv(c_, c_, k=3)
|
46 |
+
self.cv3 = Conv(c_, c2)
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
"""Performs a forward pass through layers using an upsampled input image."""
|
50 |
+
return self.cv3(self.cv2(self.upsample(self.cv1(x))))
|
51 |
+
|
52 |
+
|
53 |
+
class HGStem(nn.Module):
|
54 |
+
"""StemBlock of PPHGNetV2 with 5 convolutions and one maxpool2d.
|
55 |
+
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(self, c1, cm, c2):
|
59 |
+
super().__init__()
|
60 |
+
self.stem1 = Conv(c1, cm, 3, 2, act=nn.ReLU())
|
61 |
+
self.stem2a = Conv(cm, cm // 2, 2, 1, 0, act=nn.ReLU())
|
62 |
+
self.stem2b = Conv(cm // 2, cm, 2, 1, 0, act=nn.ReLU())
|
63 |
+
self.stem3 = Conv(cm * 2, cm, 3, 2, act=nn.ReLU())
|
64 |
+
self.stem4 = Conv(cm, c2, 1, 1, act=nn.ReLU())
|
65 |
+
self.pool = nn.MaxPool2d(kernel_size=2, stride=1, padding=0, ceil_mode=True)
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
"""Forward pass of a PPHGNetV2 backbone layer."""
|
69 |
+
x = self.stem1(x)
|
70 |
+
x = F.pad(x, [0, 1, 0, 1])
|
71 |
+
x2 = self.stem2a(x)
|
72 |
+
x2 = F.pad(x2, [0, 1, 0, 1])
|
73 |
+
x2 = self.stem2b(x2)
|
74 |
+
x1 = self.pool(x)
|
75 |
+
x = torch.cat([x1, x2], dim=1)
|
76 |
+
x = self.stem3(x)
|
77 |
+
x = self.stem4(x)
|
78 |
+
return x
|
79 |
+
|
80 |
+
|
81 |
+
class HGBlock(nn.Module):
|
82 |
+
"""HG_Block of PPHGNetV2 with 2 convolutions and LightConv.
|
83 |
+
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
|
84 |
+
"""
|
85 |
+
|
86 |
+
def __init__(self, c1, cm, c2, k=3, n=6, lightconv=False, shortcut=False, act=nn.ReLU()):
|
87 |
+
super().__init__()
|
88 |
+
block = LightConv if lightconv else Conv
|
89 |
+
self.m = nn.ModuleList(block(c1 if i == 0 else cm, cm, k=k, act=act) for i in range(n))
|
90 |
+
self.sc = Conv(c1 + n * cm, c2 // 2, 1, 1, act=act) # squeeze conv
|
91 |
+
self.ec = Conv(c2 // 2, c2, 1, 1, act=act) # excitation conv
|
92 |
+
self.add = shortcut and c1 == c2
|
93 |
+
|
94 |
+
def forward(self, x):
|
95 |
+
"""Forward pass of a PPHGNetV2 backbone layer."""
|
96 |
+
y = [x]
|
97 |
+
y.extend(m(y[-1]) for m in self.m)
|
98 |
+
y = self.ec(self.sc(torch.cat(y, 1)))
|
99 |
+
return y + x if self.add else y
|
100 |
+
|
101 |
+
|
102 |
+
class SPP(nn.Module):
|
103 |
+
"""Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729."""
|
104 |
+
|
105 |
+
def __init__(self, c1, c2, k=(5, 9, 13)):
|
106 |
+
"""Initialize the SPP layer with input/output channels and pooling kernel sizes."""
|
107 |
+
super().__init__()
|
108 |
+
c_ = c1 // 2 # hidden channels
|
109 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
110 |
+
self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
|
111 |
+
self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
"""Forward pass of the SPP layer, performing spatial pyramid pooling."""
|
115 |
+
x = self.cv1(x)
|
116 |
+
return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
|
117 |
+
|
118 |
+
|
119 |
+
class SPPF(nn.Module):
|
120 |
+
"""Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher."""
|
121 |
+
|
122 |
+
def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
|
123 |
+
super().__init__()
|
124 |
+
c_ = c1 // 2 # hidden channels
|
125 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
126 |
+
self.cv2 = Conv(c_ * 4, c2, 1, 1)
|
127 |
+
self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
|
128 |
+
|
129 |
+
def forward(self, x):
|
130 |
+
"""Forward pass through Ghost Convolution block."""
|
131 |
+
x = self.cv1(x)
|
132 |
+
y1 = self.m(x)
|
133 |
+
y2 = self.m(y1)
|
134 |
+
return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
|
135 |
+
|
136 |
+
|
137 |
+
class C1(nn.Module):
|
138 |
+
"""CSP Bottleneck with 1 convolution."""
|
139 |
+
|
140 |
+
def __init__(self, c1, c2, n=1): # ch_in, ch_out, number
|
141 |
+
super().__init__()
|
142 |
+
self.cv1 = Conv(c1, c2, 1, 1)
|
143 |
+
self.m = nn.Sequential(*(Conv(c2, c2, 3) for _ in range(n)))
|
144 |
+
|
145 |
+
def forward(self, x):
|
146 |
+
"""Applies cross-convolutions to input in the C3 module."""
|
147 |
+
y = self.cv1(x)
|
148 |
+
return self.m(y) + y
|
149 |
+
|
150 |
+
|
151 |
+
class C2(nn.Module):
|
152 |
+
"""CSP Bottleneck with 2 convolutions."""
|
153 |
+
|
154 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
155 |
+
super().__init__()
|
156 |
+
self.c = int(c2 * e) # hidden channels
|
157 |
+
self.cv1 = Conv(c1, 2 * self.c, 1, 1)
|
158 |
+
self.cv2 = Conv(2 * self.c, c2, 1) # optional act=FReLU(c2)
|
159 |
+
# self.attention = ChannelAttention(2 * self.c) # or SpatialAttention()
|
160 |
+
self.m = nn.Sequential(*(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)))
|
161 |
+
|
162 |
+
def forward(self, x):
|
163 |
+
"""Forward pass through the CSP bottleneck with 2 convolutions."""
|
164 |
+
a, b = self.cv1(x).chunk(2, 1)
|
165 |
+
return self.cv2(torch.cat((self.m(a), b), 1))
|
166 |
+
|
167 |
+
|
168 |
+
class C2f(nn.Module):
|
169 |
+
"""CSP Bottleneck with 2 convolutions."""
|
170 |
+
|
171 |
+
def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
172 |
+
super().__init__()
|
173 |
+
self.c = int(c2 * e) # hidden channels
|
174 |
+
self.cv1 = Conv(c1, 2 * self.c, 1, 1)
|
175 |
+
self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2)
|
176 |
+
self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
|
177 |
+
|
178 |
+
def forward(self, x):
|
179 |
+
"""Forward pass through C2f layer."""
|
180 |
+
y = list(self.cv1(x).chunk(2, 1))
|
181 |
+
y.extend(m(y[-1]) for m in self.m)
|
182 |
+
return self.cv2(torch.cat(y, 1))
|
183 |
+
|
184 |
+
def forward_split(self, x):
|
185 |
+
"""Forward pass using split() instead of chunk()."""
|
186 |
+
y = list(self.cv1(x).split((self.c, self.c), 1))
|
187 |
+
y.extend(m(y[-1]) for m in self.m)
|
188 |
+
return self.cv2(torch.cat(y, 1))
|
189 |
+
|
190 |
+
|
191 |
+
class C3(nn.Module):
|
192 |
+
"""CSP Bottleneck with 3 convolutions."""
|
193 |
+
|
194 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
195 |
+
super().__init__()
|
196 |
+
c_ = int(c2 * e) # hidden channels
|
197 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
198 |
+
self.cv2 = Conv(c1, c_, 1, 1)
|
199 |
+
self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
|
200 |
+
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n)))
|
201 |
+
|
202 |
+
def forward(self, x):
|
203 |
+
"""Forward pass through the CSP bottleneck with 2 convolutions."""
|
204 |
+
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
|
205 |
+
|
206 |
+
|
207 |
+
class C3x(C3):
|
208 |
+
"""C3 module with cross-convolutions."""
|
209 |
+
|
210 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
|
211 |
+
"""Initialize C3TR instance and set default parameters."""
|
212 |
+
super().__init__(c1, c2, n, shortcut, g, e)
|
213 |
+
self.c_ = int(c2 * e)
|
214 |
+
self.m = nn.Sequential(*(Bottleneck(self.c_, self.c_, shortcut, g, k=((1, 3), (3, 1)), e=1) for _ in range(n)))
|
215 |
+
|
216 |
+
|
217 |
+
class RepC3(nn.Module):
|
218 |
+
"""Rep C3."""
|
219 |
+
|
220 |
+
def __init__(self, c1, c2, n=3, e=1.0):
|
221 |
+
super().__init__()
|
222 |
+
c_ = int(c2 * e) # hidden channels
|
223 |
+
self.cv1 = Conv(c1, c2, 1, 1)
|
224 |
+
self.cv2 = Conv(c1, c2, 1, 1)
|
225 |
+
self.m = nn.Sequential(*[RepConv(c_, c_) for _ in range(n)])
|
226 |
+
self.cv3 = Conv(c_, c2, 1, 1) if c_ != c2 else nn.Identity()
|
227 |
+
|
228 |
+
def forward(self, x):
|
229 |
+
"""Forward pass of RT-DETR neck layer."""
|
230 |
+
return self.cv3(self.m(self.cv1(x)) + self.cv2(x))
|
231 |
+
|
232 |
+
|
233 |
+
class C3TR(C3):
|
234 |
+
"""C3 module with TransformerBlock()."""
|
235 |
+
|
236 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
|
237 |
+
"""Initialize C3Ghost module with GhostBottleneck()."""
|
238 |
+
super().__init__(c1, c2, n, shortcut, g, e)
|
239 |
+
c_ = int(c2 * e)
|
240 |
+
self.m = TransformerBlock(c_, c_, 4, n)
|
241 |
+
|
242 |
+
|
243 |
+
class C3Ghost(C3):
|
244 |
+
"""C3 module with GhostBottleneck()."""
|
245 |
+
|
246 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
|
247 |
+
"""Initialize 'SPP' module with various pooling sizes for spatial pyramid pooling."""
|
248 |
+
super().__init__(c1, c2, n, shortcut, g, e)
|
249 |
+
c_ = int(c2 * e) # hidden channels
|
250 |
+
self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))
|
251 |
+
|
252 |
+
|
253 |
+
class GhostBottleneck(nn.Module):
|
254 |
+
"""Ghost Bottleneck https://github.com/huawei-noah/ghostnet."""
|
255 |
+
|
256 |
+
def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
|
257 |
+
super().__init__()
|
258 |
+
c_ = c2 // 2
|
259 |
+
self.conv = nn.Sequential(
|
260 |
+
GhostConv(c1, c_, 1, 1), # pw
|
261 |
+
DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
|
262 |
+
GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
|
263 |
+
self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1,
|
264 |
+
act=False)) if s == 2 else nn.Identity()
|
265 |
+
|
266 |
+
def forward(self, x):
|
267 |
+
"""Applies skip connection and concatenation to input tensor."""
|
268 |
+
return self.conv(x) + self.shortcut(x)
|
269 |
+
|
270 |
+
|
271 |
+
class Bottleneck(nn.Module):
|
272 |
+
"""Standard bottleneck."""
|
273 |
+
|
274 |
+
def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, groups, kernels, expand
|
275 |
+
super().__init__()
|
276 |
+
c_ = int(c2 * e) # hidden channels
|
277 |
+
self.cv1 = Conv(c1, c_, k[0], 1)
|
278 |
+
self.cv2 = Conv(c_, c2, k[1], 1, g=g)
|
279 |
+
self.add = shortcut and c1 == c2
|
280 |
+
|
281 |
+
def forward(self, x):
|
282 |
+
"""'forward()' applies the YOLOv5 FPN to input data."""
|
283 |
+
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
|
284 |
+
|
285 |
+
|
286 |
+
class BottleneckCSP(nn.Module):
|
287 |
+
"""CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks."""
|
288 |
+
|
289 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
290 |
+
super().__init__()
|
291 |
+
c_ = int(c2 * e) # hidden channels
|
292 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
293 |
+
self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
|
294 |
+
self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
|
295 |
+
self.cv4 = Conv(2 * c_, c2, 1, 1)
|
296 |
+
self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
|
297 |
+
self.act = nn.SiLU()
|
298 |
+
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
|
299 |
+
|
300 |
+
def forward(self, x):
|
301 |
+
"""Applies a CSP bottleneck with 3 convolutions."""
|
302 |
+
y1 = self.cv3(self.m(self.cv1(x)))
|
303 |
+
y2 = self.cv2(x)
|
304 |
+
return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))
|
ultralytics/nn/modules/conv.py
ADDED
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
"""
|
3 |
+
Convolution modules
|
4 |
+
"""
|
5 |
+
|
6 |
+
import math
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
__all__ = ('Conv', 'LightConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv',
|
13 |
+
'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'RepConv')
|
14 |
+
|
15 |
+
|
16 |
+
def autopad(k, p=None, d=1): # kernel, padding, dilation
|
17 |
+
"""Pad to 'same' shape outputs."""
|
18 |
+
if d > 1:
|
19 |
+
k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
|
20 |
+
if p is None:
|
21 |
+
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
|
22 |
+
return p
|
23 |
+
|
24 |
+
|
25 |
+
class Conv(nn.Module):
|
26 |
+
"""Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
|
27 |
+
default_act = nn.SiLU() # default activation
|
28 |
+
|
29 |
+
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
|
30 |
+
"""Initialize Conv layer with given arguments including activation."""
|
31 |
+
super().__init__()
|
32 |
+
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
|
33 |
+
self.bn = nn.BatchNorm2d(c2)
|
34 |
+
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
"""Apply convolution, batch normalization and activation to input tensor."""
|
38 |
+
return self.act(self.bn(self.conv(x)))
|
39 |
+
|
40 |
+
def forward_fuse(self, x):
|
41 |
+
"""Perform transposed convolution of 2D data."""
|
42 |
+
return self.act(self.conv(x))
|
43 |
+
|
44 |
+
|
45 |
+
class Conv2(Conv):
|
46 |
+
"""Simplified RepConv module with Conv fusing."""
|
47 |
+
|
48 |
+
def __init__(self, c1, c2, k=3, s=1, p=None, g=1, d=1, act=True):
|
49 |
+
"""Initialize Conv layer with given arguments including activation."""
|
50 |
+
super().__init__(c1, c2, k, s, p, g=g, d=d, act=act)
|
51 |
+
self.cv2 = nn.Conv2d(c1, c2, 1, s, autopad(1, p, d), groups=g, dilation=d, bias=False) # add 1x1 conv
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
"""Apply convolution, batch normalization and activation to input tensor."""
|
55 |
+
return self.act(self.bn(self.conv(x) + self.cv2(x)))
|
56 |
+
|
57 |
+
def fuse_convs(self):
|
58 |
+
"""Fuse parallel convolutions."""
|
59 |
+
w = torch.zeros_like(self.conv.weight.data)
|
60 |
+
i = [x // 2 for x in w.shape[2:]]
|
61 |
+
w[:, :, i[0]:i[0] + 1, i[1]:i[1] + 1] = self.cv2.weight.data.clone()
|
62 |
+
self.conv.weight.data += w
|
63 |
+
self.__delattr__('cv2')
|
64 |
+
|
65 |
+
|
66 |
+
class LightConv(nn.Module):
|
67 |
+
"""Light convolution with args(ch_in, ch_out, kernel).
|
68 |
+
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
|
69 |
+
"""
|
70 |
+
|
71 |
+
def __init__(self, c1, c2, k=1, act=nn.ReLU()):
|
72 |
+
"""Initialize Conv layer with given arguments including activation."""
|
73 |
+
super().__init__()
|
74 |
+
self.conv1 = Conv(c1, c2, 1, act=False)
|
75 |
+
self.conv2 = DWConv(c2, c2, k, act=act)
|
76 |
+
|
77 |
+
def forward(self, x):
|
78 |
+
"""Apply 2 convolutions to input tensor."""
|
79 |
+
return self.conv2(self.conv1(x))
|
80 |
+
|
81 |
+
|
82 |
+
class DWConv(Conv):
|
83 |
+
"""Depth-wise convolution."""
|
84 |
+
|
85 |
+
def __init__(self, c1, c2, k=1, s=1, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation
|
86 |
+
super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)
|
87 |
+
|
88 |
+
|
89 |
+
class DWConvTranspose2d(nn.ConvTranspose2d):
|
90 |
+
"""Depth-wise transpose convolution."""
|
91 |
+
|
92 |
+
def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stride, padding, padding_out
|
93 |
+
super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))
|
94 |
+
|
95 |
+
|
96 |
+
class ConvTranspose(nn.Module):
|
97 |
+
"""Convolution transpose 2d layer."""
|
98 |
+
default_act = nn.SiLU() # default activation
|
99 |
+
|
100 |
+
def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
|
101 |
+
"""Initialize ConvTranspose2d layer with batch normalization and activation function."""
|
102 |
+
super().__init__()
|
103 |
+
self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn)
|
104 |
+
self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity()
|
105 |
+
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
"""Applies transposed convolutions, batch normalization and activation to input."""
|
109 |
+
return self.act(self.bn(self.conv_transpose(x)))
|
110 |
+
|
111 |
+
def forward_fuse(self, x):
|
112 |
+
"""Applies activation and convolution transpose operation to input."""
|
113 |
+
return self.act(self.conv_transpose(x))
|
114 |
+
|
115 |
+
|
116 |
+
class Focus(nn.Module):
|
117 |
+
"""Focus wh information into c-space."""
|
118 |
+
|
119 |
+
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
|
120 |
+
super().__init__()
|
121 |
+
self.conv = Conv(c1 * 4, c2, k, s, p, g, act=act)
|
122 |
+
# self.contract = Contract(gain=2)
|
123 |
+
|
124 |
+
def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
|
125 |
+
return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1))
|
126 |
+
# return self.conv(self.contract(x))
|
127 |
+
|
128 |
+
|
129 |
+
class GhostConv(nn.Module):
|
130 |
+
"""Ghost Convolution https://github.com/huawei-noah/ghostnet."""
|
131 |
+
|
132 |
+
def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
|
133 |
+
super().__init__()
|
134 |
+
c_ = c2 // 2 # hidden channels
|
135 |
+
self.cv1 = Conv(c1, c_, k, s, None, g, act=act)
|
136 |
+
self.cv2 = Conv(c_, c_, 5, 1, None, c_, act=act)
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
"""Forward propagation through a Ghost Bottleneck layer with skip connection."""
|
140 |
+
y = self.cv1(x)
|
141 |
+
return torch.cat((y, self.cv2(y)), 1)
|
142 |
+
|
143 |
+
|
144 |
+
class RepConv(nn.Module):
|
145 |
+
"""RepConv is a basic rep-style block, including training and deploy status
|
146 |
+
This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
|
147 |
+
"""
|
148 |
+
default_act = nn.SiLU() # default activation
|
149 |
+
|
150 |
+
def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):
|
151 |
+
super().__init__()
|
152 |
+
assert k == 3 and p == 1
|
153 |
+
self.g = g
|
154 |
+
self.c1 = c1
|
155 |
+
self.c2 = c2
|
156 |
+
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
|
157 |
+
|
158 |
+
self.bn = nn.BatchNorm2d(num_features=c1) if bn and c2 == c1 and s == 1 else None
|
159 |
+
self.conv1 = Conv(c1, c2, k, s, p=p, g=g, act=False)
|
160 |
+
self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False)
|
161 |
+
|
162 |
+
def forward_fuse(self, x):
|
163 |
+
"""Forward process"""
|
164 |
+
return self.act(self.conv(x))
|
165 |
+
|
166 |
+
def forward(self, x):
|
167 |
+
"""Forward process"""
|
168 |
+
id_out = 0 if self.bn is None else self.bn(x)
|
169 |
+
return self.act(self.conv1(x) + self.conv2(x) + id_out)
|
170 |
+
|
171 |
+
def get_equivalent_kernel_bias(self):
|
172 |
+
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
|
173 |
+
kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
|
174 |
+
kernelid, biasid = self._fuse_bn_tensor(self.bn)
|
175 |
+
return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
|
176 |
+
|
177 |
+
def _avg_to_3x3_tensor(self, avgp):
|
178 |
+
channels = self.c1
|
179 |
+
groups = self.g
|
180 |
+
kernel_size = avgp.kernel_size
|
181 |
+
input_dim = channels // groups
|
182 |
+
k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
|
183 |
+
k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
|
184 |
+
return k
|
185 |
+
|
186 |
+
def _pad_1x1_to_3x3_tensor(self, kernel1x1):
|
187 |
+
if kernel1x1 is None:
|
188 |
+
return 0
|
189 |
+
else:
|
190 |
+
return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
|
191 |
+
|
192 |
+
def _fuse_bn_tensor(self, branch):
|
193 |
+
if branch is None:
|
194 |
+
return 0, 0
|
195 |
+
if isinstance(branch, Conv):
|
196 |
+
kernel = branch.conv.weight
|
197 |
+
running_mean = branch.bn.running_mean
|
198 |
+
running_var = branch.bn.running_var
|
199 |
+
gamma = branch.bn.weight
|
200 |
+
beta = branch.bn.bias
|
201 |
+
eps = branch.bn.eps
|
202 |
+
elif isinstance(branch, nn.BatchNorm2d):
|
203 |
+
if not hasattr(self, 'id_tensor'):
|
204 |
+
input_dim = self.c1 // self.g
|
205 |
+
kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32)
|
206 |
+
for i in range(self.c1):
|
207 |
+
kernel_value[i, i % input_dim, 1, 1] = 1
|
208 |
+
self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
|
209 |
+
kernel = self.id_tensor
|
210 |
+
running_mean = branch.running_mean
|
211 |
+
running_var = branch.running_var
|
212 |
+
gamma = branch.weight
|
213 |
+
beta = branch.bias
|
214 |
+
eps = branch.eps
|
215 |
+
std = (running_var + eps).sqrt()
|
216 |
+
t = (gamma / std).reshape(-1, 1, 1, 1)
|
217 |
+
return kernel * t, beta - running_mean * gamma / std
|
218 |
+
|
219 |
+
def fuse_convs(self):
|
220 |
+
if hasattr(self, 'conv'):
|
221 |
+
return
|
222 |
+
kernel, bias = self.get_equivalent_kernel_bias()
|
223 |
+
self.conv = nn.Conv2d(in_channels=self.conv1.conv.in_channels,
|
224 |
+
out_channels=self.conv1.conv.out_channels,
|
225 |
+
kernel_size=self.conv1.conv.kernel_size,
|
226 |
+
stride=self.conv1.conv.stride,
|
227 |
+
padding=self.conv1.conv.padding,
|
228 |
+
dilation=self.conv1.conv.dilation,
|
229 |
+
groups=self.conv1.conv.groups,
|
230 |
+
bias=True).requires_grad_(False)
|
231 |
+
self.conv.weight.data = kernel
|
232 |
+
self.conv.bias.data = bias
|
233 |
+
for para in self.parameters():
|
234 |
+
para.detach_()
|
235 |
+
self.__delattr__('conv1')
|
236 |
+
self.__delattr__('conv2')
|
237 |
+
if hasattr(self, 'nm'):
|
238 |
+
self.__delattr__('nm')
|
239 |
+
if hasattr(self, 'bn'):
|
240 |
+
self.__delattr__('bn')
|
241 |
+
if hasattr(self, 'id_tensor'):
|
242 |
+
self.__delattr__('id_tensor')
|
243 |
+
|
244 |
+
|
245 |
+
class ChannelAttention(nn.Module):
|
246 |
+
"""Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""
|
247 |
+
|
248 |
+
def __init__(self, channels: int) -> None:
|
249 |
+
super().__init__()
|
250 |
+
self.pool = nn.AdaptiveAvgPool2d(1)
|
251 |
+
self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
|
252 |
+
self.act = nn.Sigmoid()
|
253 |
+
|
254 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
255 |
+
return x * self.act(self.fc(self.pool(x)))
|
256 |
+
|
257 |
+
|
258 |
+
class SpatialAttention(nn.Module):
|
259 |
+
"""Spatial-attention module."""
|
260 |
+
|
261 |
+
def __init__(self, kernel_size=7):
|
262 |
+
"""Initialize Spatial-attention module with kernel size argument."""
|
263 |
+
super().__init__()
|
264 |
+
assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
|
265 |
+
padding = 3 if kernel_size == 7 else 1
|
266 |
+
self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
|
267 |
+
self.act = nn.Sigmoid()
|
268 |
+
|
269 |
+
def forward(self, x):
|
270 |
+
"""Apply channel and spatial attention on input for feature recalibration."""
|
271 |
+
return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))
|
272 |
+
|
273 |
+
|
274 |
+
class CBAM(nn.Module):
|
275 |
+
"""Convolutional Block Attention Module."""
|
276 |
+
|
277 |
+
def __init__(self, c1, kernel_size=7): # ch_in, kernels
|
278 |
+
super().__init__()
|
279 |
+
self.channel_attention = ChannelAttention(c1)
|
280 |
+
self.spatial_attention = SpatialAttention(kernel_size)
|
281 |
+
|
282 |
+
def forward(self, x):
|
283 |
+
"""Applies the forward pass through C1 module."""
|
284 |
+
return self.spatial_attention(self.channel_attention(x))
|
285 |
+
|
286 |
+
|
287 |
+
class Concat(nn.Module):
|
288 |
+
"""Concatenate a list of tensors along dimension."""
|
289 |
+
|
290 |
+
def __init__(self, dimension=1):
|
291 |
+
"""Concatenates a list of tensors along a specified dimension."""
|
292 |
+
super().__init__()
|
293 |
+
self.d = dimension
|
294 |
+
|
295 |
+
def forward(self, x):
|
296 |
+
"""Forward pass for the YOLOv8 mask Proto module."""
|
297 |
+
return torch.cat(x, self.d)
|
ultralytics/nn/modules/head.py
ADDED
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
"""
|
3 |
+
Model head modules
|
4 |
+
"""
|
5 |
+
|
6 |
+
import math
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from torch.nn.init import constant_, xavier_uniform_
|
11 |
+
|
12 |
+
from ultralytics.yolo.utils.tal import dist2bbox, make_anchors
|
13 |
+
|
14 |
+
from .block import DFL, Proto
|
15 |
+
from .conv import Conv
|
16 |
+
from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
|
17 |
+
from .utils import bias_init_with_prob, linear_init_
|
18 |
+
|
19 |
+
__all__ = 'Detect', 'Segment', 'Pose', 'Classify', 'RTDETRDecoder'
|
20 |
+
|
21 |
+
|
22 |
+
class Detect(nn.Module):
|
23 |
+
"""YOLOv8 Detect head for detection models."""
|
24 |
+
dynamic = False # force grid reconstruction
|
25 |
+
export = False # export mode
|
26 |
+
shape = None
|
27 |
+
anchors = torch.empty(0) # init
|
28 |
+
strides = torch.empty(0) # init
|
29 |
+
|
30 |
+
def __init__(self, nc=80, ch=()): # detection layer
|
31 |
+
super().__init__()
|
32 |
+
self.nc = nc # number of classes
|
33 |
+
self.nl = len(ch) # number of detection layers
|
34 |
+
self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
|
35 |
+
self.no = nc + self.reg_max * 4 # number of outputs per anchor
|
36 |
+
self.stride = torch.zeros(self.nl) # strides computed during build
|
37 |
+
c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], self.nc) # channels
|
38 |
+
self.cv2 = nn.ModuleList(
|
39 |
+
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
|
40 |
+
self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
|
41 |
+
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
|
42 |
+
|
43 |
+
def forward(self, x):
|
44 |
+
"""Concatenates and returns predicted bounding boxes and class probabilities."""
|
45 |
+
shape = x[0].shape # BCHW
|
46 |
+
for i in range(self.nl):
|
47 |
+
x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
|
48 |
+
if self.training:
|
49 |
+
return x
|
50 |
+
elif self.dynamic or self.shape != shape:
|
51 |
+
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
52 |
+
self.shape = shape
|
53 |
+
|
54 |
+
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
|
55 |
+
if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'): # avoid TF FlexSplitV ops
|
56 |
+
box = x_cat[:, :self.reg_max * 4]
|
57 |
+
cls = x_cat[:, self.reg_max * 4:]
|
58 |
+
else:
|
59 |
+
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
|
60 |
+
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
|
61 |
+
y = torch.cat((dbox, cls.sigmoid()), 1)
|
62 |
+
return y if self.export else (y, x)
|
63 |
+
|
64 |
+
def bias_init(self):
|
65 |
+
"""Initialize Detect() biases, WARNING: requires stride availability."""
|
66 |
+
m = self # self.model[-1] # Detect() module
|
67 |
+
# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
|
68 |
+
# ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
|
69 |
+
for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
|
70 |
+
a[-1].bias.data[:] = 1.0 # box
|
71 |
+
b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
|
72 |
+
|
73 |
+
|
74 |
+
class Segment(Detect):
|
75 |
+
"""YOLOv8 Segment head for segmentation models."""
|
76 |
+
|
77 |
+
def __init__(self, nc=80, nm=32, npr=256, ch=()):
|
78 |
+
"""Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers."""
|
79 |
+
super().__init__(nc, ch)
|
80 |
+
self.nm = nm # number of masks
|
81 |
+
self.npr = npr # number of protos
|
82 |
+
self.proto = Proto(ch[0], self.npr, self.nm) # protos
|
83 |
+
self.detect = Detect.forward
|
84 |
+
|
85 |
+
c4 = max(ch[0] // 4, self.nm)
|
86 |
+
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
|
87 |
+
|
88 |
+
def forward(self, x):
|
89 |
+
"""Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
|
90 |
+
p = self.proto(x[0]) # mask protos
|
91 |
+
bs = p.shape[0] # batch size
|
92 |
+
|
93 |
+
mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
|
94 |
+
x = self.detect(self, x)
|
95 |
+
if self.training:
|
96 |
+
return x, mc, p
|
97 |
+
return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
|
98 |
+
|
99 |
+
|
100 |
+
class Pose(Detect):
|
101 |
+
"""YOLOv8 Pose head for keypoints models."""
|
102 |
+
|
103 |
+
def __init__(self, nc=80, kpt_shape=(17, 3), ch=()):
|
104 |
+
"""Initialize YOLO network with default parameters and Convolutional Layers."""
|
105 |
+
super().__init__(nc, ch)
|
106 |
+
self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
|
107 |
+
self.nk = kpt_shape[0] * kpt_shape[1] # number of keypoints total
|
108 |
+
self.detect = Detect.forward
|
109 |
+
|
110 |
+
c4 = max(ch[0] // 4, self.nk)
|
111 |
+
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
"""Perform forward pass through YOLO model and return predictions."""
|
115 |
+
bs = x[0].shape[0] # batch size
|
116 |
+
kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
|
117 |
+
x = self.detect(self, x)
|
118 |
+
if self.training:
|
119 |
+
return x, kpt
|
120 |
+
pred_kpt = self.kpts_decode(bs, kpt)
|
121 |
+
return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))
|
122 |
+
|
123 |
+
def kpts_decode(self, bs, kpts):
|
124 |
+
"""Decodes keypoints."""
|
125 |
+
ndim = self.kpt_shape[1]
|
126 |
+
if self.export: # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
|
127 |
+
y = kpts.view(bs, *self.kpt_shape, -1)
|
128 |
+
a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
|
129 |
+
if ndim == 3:
|
130 |
+
a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
|
131 |
+
return a.view(bs, self.nk, -1)
|
132 |
+
else:
|
133 |
+
y = kpts.clone()
|
134 |
+
if ndim == 3:
|
135 |
+
y[:, 2::3].sigmoid_() # inplace sigmoid
|
136 |
+
y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides
|
137 |
+
y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides
|
138 |
+
return y
|
139 |
+
|
140 |
+
|
141 |
+
class Classify(nn.Module):
|
142 |
+
"""YOLOv8 classification head, i.e. x(b,c1,20,20) to x(b,c2)."""
|
143 |
+
|
144 |
+
def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
|
145 |
+
super().__init__()
|
146 |
+
c_ = 1280 # efficientnet_b0 size
|
147 |
+
self.conv = Conv(c1, c_, k, s, p, g)
|
148 |
+
self.pool = nn.AdaptiveAvgPool2d(1) # to x(b,c_,1,1)
|
149 |
+
self.drop = nn.Dropout(p=0.0, inplace=True)
|
150 |
+
self.linear = nn.Linear(c_, c2) # to x(b,c2)
|
151 |
+
|
152 |
+
def forward(self, x):
|
153 |
+
"""Performs a forward pass of the YOLO model on input image data."""
|
154 |
+
if isinstance(x, list):
|
155 |
+
x = torch.cat(x, 1)
|
156 |
+
x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
|
157 |
+
return x if self.training else x.softmax(1)
|
158 |
+
|
159 |
+
|
160 |
+
class RTDETRDecoder(nn.Module):
|
161 |
+
|
162 |
+
def __init__(
|
163 |
+
self,
|
164 |
+
nc=80,
|
165 |
+
ch=(512, 1024, 2048),
|
166 |
+
hidden_dim=256,
|
167 |
+
num_queries=300,
|
168 |
+
strides=(8, 16, 32), # TODO
|
169 |
+
nl=3,
|
170 |
+
num_decoder_points=4,
|
171 |
+
nhead=8,
|
172 |
+
num_decoder_layers=6,
|
173 |
+
dim_feedforward=1024,
|
174 |
+
dropout=0.,
|
175 |
+
act=nn.ReLU(),
|
176 |
+
eval_idx=-1,
|
177 |
+
# training args
|
178 |
+
num_denoising=100,
|
179 |
+
label_noise_ratio=0.5,
|
180 |
+
box_noise_scale=1.0,
|
181 |
+
learnt_init_query=False):
|
182 |
+
super().__init__()
|
183 |
+
assert len(ch) <= nl
|
184 |
+
assert len(strides) == len(ch)
|
185 |
+
for _ in range(nl - len(strides)):
|
186 |
+
strides.append(strides[-1] * 2)
|
187 |
+
|
188 |
+
self.hidden_dim = hidden_dim
|
189 |
+
self.nhead = nhead
|
190 |
+
self.feat_strides = strides
|
191 |
+
self.nl = nl
|
192 |
+
self.nc = nc
|
193 |
+
self.num_queries = num_queries
|
194 |
+
self.num_decoder_layers = num_decoder_layers
|
195 |
+
|
196 |
+
# backbone feature projection
|
197 |
+
self._build_input_proj_layer(ch)
|
198 |
+
|
199 |
+
# Transformer module
|
200 |
+
decoder_layer = DeformableTransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, dropout, act, nl,
|
201 |
+
num_decoder_points)
|
202 |
+
self.decoder = DeformableTransformerDecoder(hidden_dim, decoder_layer, num_decoder_layers, eval_idx)
|
203 |
+
|
204 |
+
# denoising part
|
205 |
+
self.denoising_class_embed = nn.Embedding(nc, hidden_dim)
|
206 |
+
self.num_denoising = num_denoising
|
207 |
+
self.label_noise_ratio = label_noise_ratio
|
208 |
+
self.box_noise_scale = box_noise_scale
|
209 |
+
|
210 |
+
# decoder embedding
|
211 |
+
self.learnt_init_query = learnt_init_query
|
212 |
+
if learnt_init_query:
|
213 |
+
self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
|
214 |
+
self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2)
|
215 |
+
|
216 |
+
# encoder head
|
217 |
+
self.enc_output = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim))
|
218 |
+
self.enc_score_head = nn.Linear(hidden_dim, nc)
|
219 |
+
self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)
|
220 |
+
|
221 |
+
# decoder head
|
222 |
+
self.dec_score_head = nn.ModuleList([nn.Linear(hidden_dim, nc) for _ in range(num_decoder_layers)])
|
223 |
+
self.dec_bbox_head = nn.ModuleList([
|
224 |
+
MLP(hidden_dim, hidden_dim, 4, num_layers=3) for _ in range(num_decoder_layers)])
|
225 |
+
|
226 |
+
self._reset_parameters()
|
227 |
+
|
228 |
+
def forward(self, feats, gt_meta=None):
|
229 |
+
# input projection and embedding
|
230 |
+
memory, spatial_shapes, _ = self._get_encoder_input(feats)
|
231 |
+
|
232 |
+
# prepare denoising training
|
233 |
+
if self.training:
|
234 |
+
raise NotImplementedError
|
235 |
+
# denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
|
236 |
+
# get_contrastive_denoising_training_group(gt_meta,
|
237 |
+
# self.num_classes,
|
238 |
+
# self.num_queries,
|
239 |
+
# self.denoising_class_embed.weight,
|
240 |
+
# self.num_denoising,
|
241 |
+
# self.label_noise_ratio,
|
242 |
+
# self.box_noise_scale)
|
243 |
+
else:
|
244 |
+
denoising_class, denoising_bbox_unact, attn_mask = None, None, None
|
245 |
+
|
246 |
+
target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
|
247 |
+
self._get_decoder_input(memory, spatial_shapes, denoising_class, denoising_bbox_unact)
|
248 |
+
|
249 |
+
# decoder
|
250 |
+
out_bboxes, out_logits = self.decoder(target,
|
251 |
+
init_ref_points_unact,
|
252 |
+
memory,
|
253 |
+
spatial_shapes,
|
254 |
+
self.dec_bbox_head,
|
255 |
+
self.dec_score_head,
|
256 |
+
self.query_pos_head,
|
257 |
+
attn_mask=attn_mask)
|
258 |
+
if not self.training:
|
259 |
+
out_logits = out_logits.sigmoid_()
|
260 |
+
return out_bboxes, out_logits # enc_topk_bboxes, enc_topk_logits, dn_meta
|
261 |
+
|
262 |
+
def _reset_parameters(self):
|
263 |
+
# class and bbox head init
|
264 |
+
bias_cls = bias_init_with_prob(0.01)
|
265 |
+
linear_init_(self.enc_score_head)
|
266 |
+
constant_(self.enc_score_head.bias, bias_cls)
|
267 |
+
constant_(self.enc_bbox_head.layers[-1].weight, 0.)
|
268 |
+
constant_(self.enc_bbox_head.layers[-1].bias, 0.)
|
269 |
+
for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
|
270 |
+
linear_init_(cls_)
|
271 |
+
constant_(cls_.bias, bias_cls)
|
272 |
+
constant_(reg_.layers[-1].weight, 0.)
|
273 |
+
constant_(reg_.layers[-1].bias, 0.)
|
274 |
+
|
275 |
+
linear_init_(self.enc_output[0])
|
276 |
+
xavier_uniform_(self.enc_output[0].weight)
|
277 |
+
if self.learnt_init_query:
|
278 |
+
xavier_uniform_(self.tgt_embed.weight)
|
279 |
+
xavier_uniform_(self.query_pos_head.layers[0].weight)
|
280 |
+
xavier_uniform_(self.query_pos_head.layers[1].weight)
|
281 |
+
for layer in self.input_proj:
|
282 |
+
xavier_uniform_(layer[0].weight)
|
283 |
+
|
284 |
+
def _build_input_proj_layer(self, ch):
|
285 |
+
self.input_proj = nn.ModuleList()
|
286 |
+
for in_channels in ch:
|
287 |
+
self.input_proj.append(
|
288 |
+
nn.Sequential(nn.Conv2d(in_channels, self.hidden_dim, kernel_size=1, bias=False),
|
289 |
+
nn.BatchNorm2d(self.hidden_dim)))
|
290 |
+
in_channels = ch[-1]
|
291 |
+
for _ in range(self.nl - len(ch)):
|
292 |
+
self.input_proj.append(
|
293 |
+
nn.Sequential(nn.Conv2D(in_channels, self.hidden_dim, kernel_size=3, stride=2, padding=1, bias=False),
|
294 |
+
nn.BatchNorm2d(self.hidden_dim)))
|
295 |
+
in_channels = self.hidden_dim
|
296 |
+
|
297 |
+
def _generate_anchors(self, spatial_shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):
|
298 |
+
anchors = []
|
299 |
+
for lvl, (h, w) in enumerate(spatial_shapes):
|
300 |
+
grid_y, grid_x = torch.meshgrid(torch.arange(end=h, dtype=torch.float32),
|
301 |
+
torch.arange(end=w, dtype=torch.float32),
|
302 |
+
indexing='ij')
|
303 |
+
grid_xy = torch.stack([grid_x, grid_y], -1)
|
304 |
+
|
305 |
+
valid_WH = torch.tensor([h, w]).to(torch.float32)
|
306 |
+
grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
|
307 |
+
wh = torch.ones_like(grid_xy) * grid_size * (2.0 ** lvl)
|
308 |
+
anchors.append(torch.concat([grid_xy, wh], -1).reshape([-1, h * w, 4]))
|
309 |
+
|
310 |
+
anchors = torch.concat(anchors, 1)
|
311 |
+
valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)
|
312 |
+
anchors = torch.log(anchors / (1 - anchors))
|
313 |
+
anchors = torch.where(valid_mask, anchors, torch.inf)
|
314 |
+
return anchors.to(device=device, dtype=dtype), valid_mask.to(device=device)
|
315 |
+
|
316 |
+
def _get_encoder_input(self, feats):
|
317 |
+
# get projection features
|
318 |
+
proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
|
319 |
+
if self.nl > len(proj_feats):
|
320 |
+
len_srcs = len(proj_feats)
|
321 |
+
for i in range(len_srcs, self.nl):
|
322 |
+
if i == len_srcs:
|
323 |
+
proj_feats.append(self.input_proj[i](feats[-1]))
|
324 |
+
else:
|
325 |
+
proj_feats.append(self.input_proj[i](proj_feats[-1]))
|
326 |
+
|
327 |
+
# get encoder inputs
|
328 |
+
feat_flatten = []
|
329 |
+
spatial_shapes = []
|
330 |
+
level_start_index = [0]
|
331 |
+
for feat in proj_feats:
|
332 |
+
_, _, h, w = feat.shape
|
333 |
+
# [b, c, h, w] -> [b, h*w, c]
|
334 |
+
feat_flatten.append(feat.flatten(2).permute(0, 2, 1))
|
335 |
+
# [nl, 2]
|
336 |
+
spatial_shapes.append([h, w])
|
337 |
+
# [l], start index of each level
|
338 |
+
level_start_index.append(h * w + level_start_index[-1])
|
339 |
+
|
340 |
+
# [b, l, c]
|
341 |
+
feat_flatten = torch.concat(feat_flatten, 1)
|
342 |
+
level_start_index.pop()
|
343 |
+
return feat_flatten, spatial_shapes, level_start_index
|
344 |
+
|
345 |
+
def _get_decoder_input(self, memory, spatial_shapes, denoising_class=None, denoising_bbox_unact=None):
|
346 |
+
bs, _, _ = memory.shape
|
347 |
+
# prepare input for decoder
|
348 |
+
anchors, valid_mask = self._generate_anchors(spatial_shapes, dtype=memory.dtype, device=memory.device)
|
349 |
+
memory = torch.where(valid_mask, memory, 0)
|
350 |
+
output_memory = self.enc_output(memory)
|
351 |
+
|
352 |
+
enc_outputs_class = self.enc_score_head(output_memory) # (bs, h*w, nc)
|
353 |
+
enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors # (bs, h*w, 4)
|
354 |
+
|
355 |
+
# (bs, topk)
|
356 |
+
_, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.num_queries, dim=1)
|
357 |
+
# extract region proposal boxes
|
358 |
+
# (bs, topk_ind)
|
359 |
+
batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)
|
360 |
+
topk_ind = topk_ind.view(-1)
|
361 |
+
|
362 |
+
# Unsigmoided
|
363 |
+
reference_points_unact = enc_outputs_coord_unact[batch_ind, topk_ind].view(bs, self.num_queries, -1)
|
364 |
+
|
365 |
+
enc_topk_bboxes = torch.sigmoid(reference_points_unact)
|
366 |
+
if denoising_bbox_unact is not None:
|
367 |
+
reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1)
|
368 |
+
if self.training:
|
369 |
+
reference_points_unact = reference_points_unact.detach()
|
370 |
+
enc_topk_logits = enc_outputs_class[batch_ind, topk_ind].view(bs, self.num_queries, -1)
|
371 |
+
|
372 |
+
# extract region features
|
373 |
+
if self.learnt_init_query:
|
374 |
+
target = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
|
375 |
+
else:
|
376 |
+
target = output_memory[batch_ind, topk_ind].view(bs, self.num_queries, -1)
|
377 |
+
if self.training:
|
378 |
+
target = target.detach()
|
379 |
+
if denoising_class is not None:
|
380 |
+
target = torch.concat([denoising_class, target], 1)
|
381 |
+
|
382 |
+
return target, reference_points_unact, enc_topk_bboxes, enc_topk_logits
|
ultralytics/nn/modules/transformer.py
ADDED
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
"""
|
3 |
+
Transformer modules
|
4 |
+
"""
|
5 |
+
|
6 |
+
import math
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch.nn.init import constant_, xavier_uniform_
|
12 |
+
|
13 |
+
from .conv import Conv
|
14 |
+
from .utils import _get_clones, inverse_sigmoid, multi_scale_deformable_attn_pytorch
|
15 |
+
|
16 |
+
__all__ = ('TransformerEncoderLayer', 'TransformerLayer', 'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'AIFI',
|
17 |
+
'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP')
|
18 |
+
|
19 |
+
|
20 |
+
class TransformerEncoderLayer(nn.Module):
|
21 |
+
"""Transformer Encoder."""
|
22 |
+
|
23 |
+
def __init__(self, c1, cm=2048, num_heads=8, dropout=0.0, act=nn.GELU(), normalize_before=False):
|
24 |
+
super().__init__()
|
25 |
+
self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)
|
26 |
+
# Implementation of Feedforward model
|
27 |
+
self.fc1 = nn.Linear(c1, cm)
|
28 |
+
self.fc2 = nn.Linear(cm, c1)
|
29 |
+
|
30 |
+
self.norm1 = nn.LayerNorm(c1)
|
31 |
+
self.norm2 = nn.LayerNorm(c1)
|
32 |
+
self.dropout = nn.Dropout(dropout)
|
33 |
+
self.dropout1 = nn.Dropout(dropout)
|
34 |
+
self.dropout2 = nn.Dropout(dropout)
|
35 |
+
|
36 |
+
self.act = act
|
37 |
+
self.normalize_before = normalize_before
|
38 |
+
|
39 |
+
def with_pos_embed(self, tensor, pos=None):
|
40 |
+
"""Add position embeddings if given."""
|
41 |
+
return tensor if pos is None else tensor + pos
|
42 |
+
|
43 |
+
def forward_post(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
|
44 |
+
q = k = self.with_pos_embed(src, pos)
|
45 |
+
src2 = self.ma(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
|
46 |
+
src = src + self.dropout1(src2)
|
47 |
+
src = self.norm1(src)
|
48 |
+
src2 = self.fc2(self.dropout(self.act(self.fc1(src))))
|
49 |
+
src = src + self.dropout2(src2)
|
50 |
+
src = self.norm2(src)
|
51 |
+
return src
|
52 |
+
|
53 |
+
def forward_pre(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
|
54 |
+
src2 = self.norm1(src)
|
55 |
+
q = k = self.with_pos_embed(src2, pos)
|
56 |
+
src2 = self.ma(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
|
57 |
+
src = src + self.dropout1(src2)
|
58 |
+
src2 = self.norm2(src)
|
59 |
+
src2 = self.fc2(self.dropout(self.act(self.fc1(src2))))
|
60 |
+
src = src + self.dropout2(src2)
|
61 |
+
return src
|
62 |
+
|
63 |
+
def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
|
64 |
+
"""Forward propagates the input through the encoder module."""
|
65 |
+
if self.normalize_before:
|
66 |
+
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
67 |
+
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
68 |
+
|
69 |
+
|
70 |
+
class AIFI(TransformerEncoderLayer):
|
71 |
+
|
72 |
+
def __init__(self, c1, cm=2048, num_heads=8, dropout=0, act=nn.GELU(), normalize_before=False):
|
73 |
+
super().__init__(c1, cm, num_heads, dropout, act, normalize_before)
|
74 |
+
|
75 |
+
def forward(self, x):
|
76 |
+
c, h, w = x.shape[1:]
|
77 |
+
pos_embed = self.build_2d_sincos_position_embedding(w, h, c)
|
78 |
+
# flatten [B, C, H, W] to [B, HxW, C]
|
79 |
+
x = super().forward(x.flatten(2).permute(0, 2, 1), pos=pos_embed.to(device=x.device, dtype=x.dtype))
|
80 |
+
return x.permute((0, 2, 1)).view([-1, c, h, w])
|
81 |
+
|
82 |
+
@staticmethod
|
83 |
+
def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.):
|
84 |
+
grid_w = torch.arange(int(w), dtype=torch.float32)
|
85 |
+
grid_h = torch.arange(int(h), dtype=torch.float32)
|
86 |
+
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij')
|
87 |
+
assert embed_dim % 4 == 0, \
|
88 |
+
'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
|
89 |
+
pos_dim = embed_dim // 4
|
90 |
+
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
|
91 |
+
omega = 1. / (temperature ** omega)
|
92 |
+
|
93 |
+
out_w = grid_w.flatten()[..., None] @ omega[None]
|
94 |
+
out_h = grid_h.flatten()[..., None] @ omega[None]
|
95 |
+
|
96 |
+
return torch.concat([torch.sin(out_w), torch.cos(out_w),
|
97 |
+
torch.sin(out_h), torch.cos(out_h)], axis=1)[None, :, :]
|
98 |
+
|
99 |
+
|
100 |
+
class TransformerLayer(nn.Module):
|
101 |
+
"""Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)."""
|
102 |
+
|
103 |
+
def __init__(self, c, num_heads):
|
104 |
+
"""Initializes a self-attention mechanism using linear transformations and multi-head attention."""
|
105 |
+
super().__init__()
|
106 |
+
self.q = nn.Linear(c, c, bias=False)
|
107 |
+
self.k = nn.Linear(c, c, bias=False)
|
108 |
+
self.v = nn.Linear(c, c, bias=False)
|
109 |
+
self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
|
110 |
+
self.fc1 = nn.Linear(c, c, bias=False)
|
111 |
+
self.fc2 = nn.Linear(c, c, bias=False)
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
"""Apply a transformer block to the input x and return the output."""
|
115 |
+
x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
|
116 |
+
x = self.fc2(self.fc1(x)) + x
|
117 |
+
return x
|
118 |
+
|
119 |
+
|
120 |
+
class TransformerBlock(nn.Module):
|
121 |
+
"""Vision Transformer https://arxiv.org/abs/2010.11929."""
|
122 |
+
|
123 |
+
def __init__(self, c1, c2, num_heads, num_layers):
|
124 |
+
"""Initialize a Transformer module with position embedding and specified number of heads and layers."""
|
125 |
+
super().__init__()
|
126 |
+
self.conv = None
|
127 |
+
if c1 != c2:
|
128 |
+
self.conv = Conv(c1, c2)
|
129 |
+
self.linear = nn.Linear(c2, c2) # learnable position embedding
|
130 |
+
self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
|
131 |
+
self.c2 = c2
|
132 |
+
|
133 |
+
def forward(self, x):
|
134 |
+
"""Forward propagates the input through the bottleneck module."""
|
135 |
+
if self.conv is not None:
|
136 |
+
x = self.conv(x)
|
137 |
+
b, _, w, h = x.shape
|
138 |
+
p = x.flatten(2).permute(2, 0, 1)
|
139 |
+
return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)
|
140 |
+
|
141 |
+
|
142 |
+
class MLPBlock(nn.Module):
|
143 |
+
|
144 |
+
def __init__(self, embedding_dim, mlp_dim, act=nn.GELU):
|
145 |
+
super().__init__()
|
146 |
+
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
|
147 |
+
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
|
148 |
+
self.act = act()
|
149 |
+
|
150 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
151 |
+
return self.lin2(self.act(self.lin1(x)))
|
152 |
+
|
153 |
+
|
154 |
+
class MLP(nn.Module):
|
155 |
+
""" Very simple multi-layer perceptron (also called FFN)"""
|
156 |
+
|
157 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
158 |
+
super().__init__()
|
159 |
+
self.num_layers = num_layers
|
160 |
+
h = [hidden_dim] * (num_layers - 1)
|
161 |
+
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
162 |
+
|
163 |
+
def forward(self, x):
|
164 |
+
for i, layer in enumerate(self.layers):
|
165 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
166 |
+
return x
|
167 |
+
|
168 |
+
|
169 |
+
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
|
170 |
+
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
|
171 |
+
class LayerNorm2d(nn.Module):
|
172 |
+
|
173 |
+
def __init__(self, num_channels, eps=1e-6):
|
174 |
+
super().__init__()
|
175 |
+
self.weight = nn.Parameter(torch.ones(num_channels))
|
176 |
+
self.bias = nn.Parameter(torch.zeros(num_channels))
|
177 |
+
self.eps = eps
|
178 |
+
|
179 |
+
def forward(self, x):
|
180 |
+
u = x.mean(1, keepdim=True)
|
181 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
182 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
183 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
184 |
+
return x
|
185 |
+
|
186 |
+
|
187 |
+
class MSDeformAttn(nn.Module):
|
188 |
+
"""
|
189 |
+
Original Multi-Scale Deformable Attention Module.
|
190 |
+
https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
|
191 |
+
"""
|
192 |
+
|
193 |
+
def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
|
194 |
+
super().__init__()
|
195 |
+
if d_model % n_heads != 0:
|
196 |
+
raise ValueError(f'd_model must be divisible by n_heads, but got {d_model} and {n_heads}')
|
197 |
+
_d_per_head = d_model // n_heads
|
198 |
+
# you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
|
199 |
+
assert _d_per_head * n_heads == d_model, '`d_model` must be divisible by `n_heads`'
|
200 |
+
|
201 |
+
self.im2col_step = 64
|
202 |
+
|
203 |
+
self.d_model = d_model
|
204 |
+
self.n_levels = n_levels
|
205 |
+
self.n_heads = n_heads
|
206 |
+
self.n_points = n_points
|
207 |
+
|
208 |
+
self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
|
209 |
+
self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
|
210 |
+
self.value_proj = nn.Linear(d_model, d_model)
|
211 |
+
self.output_proj = nn.Linear(d_model, d_model)
|
212 |
+
|
213 |
+
self._reset_parameters()
|
214 |
+
|
215 |
+
def _reset_parameters(self):
|
216 |
+
constant_(self.sampling_offsets.weight.data, 0.)
|
217 |
+
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
|
218 |
+
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
|
219 |
+
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(
|
220 |
+
1, self.n_levels, self.n_points, 1)
|
221 |
+
for i in range(self.n_points):
|
222 |
+
grid_init[:, :, i, :] *= i + 1
|
223 |
+
with torch.no_grad():
|
224 |
+
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
|
225 |
+
constant_(self.attention_weights.weight.data, 0.)
|
226 |
+
constant_(self.attention_weights.bias.data, 0.)
|
227 |
+
xavier_uniform_(self.value_proj.weight.data)
|
228 |
+
constant_(self.value_proj.bias.data, 0.)
|
229 |
+
xavier_uniform_(self.output_proj.weight.data)
|
230 |
+
constant_(self.output_proj.bias.data, 0.)
|
231 |
+
|
232 |
+
def forward(self, query, reference_points, value, value_spatial_shapes, value_mask=None):
|
233 |
+
"""
|
234 |
+
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
|
235 |
+
Args:
|
236 |
+
query (Tensor): [bs, query_length, C]
|
237 |
+
reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0),
|
238 |
+
bottom-right (1, 1), including padding area
|
239 |
+
value (Tensor): [bs, value_length, C]
|
240 |
+
value_spatial_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
|
241 |
+
value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
|
242 |
+
|
243 |
+
Returns:
|
244 |
+
output (Tensor): [bs, Length_{query}, C]
|
245 |
+
"""
|
246 |
+
bs, len_q = query.shape[:2]
|
247 |
+
_, len_v = value.shape[:2]
|
248 |
+
assert sum(s[0] * s[1] for s in value_spatial_shapes) == len_v
|
249 |
+
|
250 |
+
value = self.value_proj(value)
|
251 |
+
if value_mask is not None:
|
252 |
+
value = value.masked_fill(value_mask[..., None], float(0))
|
253 |
+
value = value.view(bs, len_v, self.n_heads, self.d_model // self.n_heads)
|
254 |
+
sampling_offsets = self.sampling_offsets(query).view(bs, len_q, self.n_heads, self.n_levels, self.n_points, 2)
|
255 |
+
attention_weights = self.attention_weights(query).view(bs, len_q, self.n_heads, self.n_levels * self.n_points)
|
256 |
+
attention_weights = F.softmax(attention_weights, -1).view(bs, len_q, self.n_heads, self.n_levels, self.n_points)
|
257 |
+
# N, Len_q, n_heads, n_levels, n_points, 2
|
258 |
+
n = reference_points.shape[-1]
|
259 |
+
if n == 2:
|
260 |
+
offset_normalizer = torch.as_tensor(value_spatial_shapes, dtype=query.dtype, device=query.device).flip(-1)
|
261 |
+
add = sampling_offsets / offset_normalizer[None, None, None, :, None, :]
|
262 |
+
sampling_locations = reference_points[:, :, None, :, None, :] + add
|
263 |
+
|
264 |
+
elif n == 4:
|
265 |
+
add = sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
|
266 |
+
sampling_locations = reference_points[:, :, None, :, None, :2] + add
|
267 |
+
else:
|
268 |
+
raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {n}.')
|
269 |
+
output = multi_scale_deformable_attn_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights)
|
270 |
+
output = self.output_proj(output)
|
271 |
+
return output
|
272 |
+
|
273 |
+
|
274 |
+
class DeformableTransformerDecoderLayer(nn.Module):
|
275 |
+
"""
|
276 |
+
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
|
277 |
+
https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/deformable_transformer.py
|
278 |
+
"""
|
279 |
+
|
280 |
+
def __init__(self, d_model=256, n_heads=8, d_ffn=1024, dropout=0., act=nn.ReLU(), n_levels=4, n_points=4):
|
281 |
+
super().__init__()
|
282 |
+
|
283 |
+
# self attention
|
284 |
+
self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
|
285 |
+
self.dropout1 = nn.Dropout(dropout)
|
286 |
+
self.norm1 = nn.LayerNorm(d_model)
|
287 |
+
|
288 |
+
# cross attention
|
289 |
+
self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
|
290 |
+
self.dropout2 = nn.Dropout(dropout)
|
291 |
+
self.norm2 = nn.LayerNorm(d_model)
|
292 |
+
|
293 |
+
# ffn
|
294 |
+
self.linear1 = nn.Linear(d_model, d_ffn)
|
295 |
+
self.act = act
|
296 |
+
self.dropout3 = nn.Dropout(dropout)
|
297 |
+
self.linear2 = nn.Linear(d_ffn, d_model)
|
298 |
+
self.dropout4 = nn.Dropout(dropout)
|
299 |
+
self.norm3 = nn.LayerNorm(d_model)
|
300 |
+
|
301 |
+
@staticmethod
|
302 |
+
def with_pos_embed(tensor, pos):
|
303 |
+
return tensor if pos is None else tensor + pos
|
304 |
+
|
305 |
+
def forward_ffn(self, tgt):
|
306 |
+
tgt2 = self.linear2(self.dropout3(self.act(self.linear1(tgt))))
|
307 |
+
tgt = tgt + self.dropout4(tgt2)
|
308 |
+
tgt = self.norm3(tgt)
|
309 |
+
return tgt
|
310 |
+
|
311 |
+
def forward(self,
|
312 |
+
tgt,
|
313 |
+
reference_points,
|
314 |
+
src,
|
315 |
+
src_spatial_shapes,
|
316 |
+
src_padding_mask=None,
|
317 |
+
attn_mask=None,
|
318 |
+
query_pos=None):
|
319 |
+
# self attention
|
320 |
+
q = k = self.with_pos_embed(tgt, query_pos)
|
321 |
+
if attn_mask is not None:
|
322 |
+
attn_mask = torch.where(attn_mask.astype('bool'), torch.zeros(attn_mask.shape, tgt.dtype),
|
323 |
+
torch.full(attn_mask.shape, float('-inf'), tgt.dtype))
|
324 |
+
tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1)
|
325 |
+
tgt = tgt + self.dropout1(tgt2)
|
326 |
+
tgt = self.norm1(tgt)
|
327 |
+
|
328 |
+
# cross attention
|
329 |
+
tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos), reference_points, src, src_spatial_shapes,
|
330 |
+
src_padding_mask)
|
331 |
+
tgt = tgt + self.dropout2(tgt2)
|
332 |
+
tgt = self.norm2(tgt)
|
333 |
+
|
334 |
+
# ffn
|
335 |
+
tgt = self.forward_ffn(tgt)
|
336 |
+
|
337 |
+
return tgt
|
338 |
+
|
339 |
+
|
340 |
+
class DeformableTransformerDecoder(nn.Module):
|
341 |
+
"""
|
342 |
+
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
|
343 |
+
"""
|
344 |
+
|
345 |
+
def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1):
|
346 |
+
super().__init__()
|
347 |
+
self.layers = _get_clones(decoder_layer, num_layers)
|
348 |
+
self.num_layers = num_layers
|
349 |
+
self.hidden_dim = hidden_dim
|
350 |
+
self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
|
351 |
+
|
352 |
+
def forward(self,
|
353 |
+
tgt,
|
354 |
+
reference_points,
|
355 |
+
src,
|
356 |
+
src_spatial_shapes,
|
357 |
+
bbox_head,
|
358 |
+
score_head,
|
359 |
+
query_pos_head,
|
360 |
+
attn_mask=None,
|
361 |
+
src_padding_mask=None):
|
362 |
+
output = tgt
|
363 |
+
dec_out_bboxes = []
|
364 |
+
dec_out_logits = []
|
365 |
+
ref_points = None
|
366 |
+
ref_points_detach = torch.sigmoid(reference_points)
|
367 |
+
for i, layer in enumerate(self.layers):
|
368 |
+
ref_points_input = ref_points_detach.unsqueeze(2)
|
369 |
+
query_pos_embed = query_pos_head(ref_points_detach)
|
370 |
+
output = layer(output, ref_points_input, src, src_spatial_shapes, src_padding_mask, attn_mask,
|
371 |
+
query_pos_embed)
|
372 |
+
|
373 |
+
inter_ref_bbox = torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach))
|
374 |
+
|
375 |
+
if self.training:
|
376 |
+
dec_out_logits.append(score_head[i](output))
|
377 |
+
if i == 0:
|
378 |
+
dec_out_bboxes.append(inter_ref_bbox)
|
379 |
+
else:
|
380 |
+
dec_out_bboxes.append(torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points)))
|
381 |
+
elif i == self.eval_idx:
|
382 |
+
dec_out_logits.append(score_head[i](output))
|
383 |
+
dec_out_bboxes.append(inter_ref_bbox)
|
384 |
+
break
|
385 |
+
|
386 |
+
ref_points = inter_ref_bbox
|
387 |
+
ref_points_detach = inter_ref_bbox.detach() if self.training else inter_ref_bbox
|
388 |
+
|
389 |
+
return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits)
|
ultralytics/nn/modules/utils.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
"""
|
3 |
+
Module utils
|
4 |
+
"""
|
5 |
+
|
6 |
+
import copy
|
7 |
+
import math
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch.nn.init import uniform_
|
14 |
+
|
15 |
+
__all__ = 'multi_scale_deformable_attn_pytorch', 'inverse_sigmoid'
|
16 |
+
|
17 |
+
|
18 |
+
def _get_clones(module, n):
|
19 |
+
return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
|
20 |
+
|
21 |
+
|
22 |
+
def bias_init_with_prob(prior_prob=0.01):
|
23 |
+
"""initialize conv/fc bias value according to a given probability value."""
|
24 |
+
return float(-np.log((1 - prior_prob) / prior_prob)) # return bias_init
|
25 |
+
|
26 |
+
|
27 |
+
def linear_init_(module):
|
28 |
+
bound = 1 / math.sqrt(module.weight.shape[0])
|
29 |
+
uniform_(module.weight, -bound, bound)
|
30 |
+
if hasattr(module, 'bias') and module.bias is not None:
|
31 |
+
uniform_(module.bias, -bound, bound)
|
32 |
+
|
33 |
+
|
34 |
+
def inverse_sigmoid(x, eps=1e-5):
|
35 |
+
x = x.clamp(min=0, max=1)
|
36 |
+
x1 = x.clamp(min=eps)
|
37 |
+
x2 = (1 - x).clamp(min=eps)
|
38 |
+
return torch.log(x1 / x2)
|
39 |
+
|
40 |
+
|
41 |
+
def multi_scale_deformable_attn_pytorch(value: torch.Tensor, value_spatial_shapes: torch.Tensor,
|
42 |
+
sampling_locations: torch.Tensor,
|
43 |
+
attention_weights: torch.Tensor) -> torch.Tensor:
|
44 |
+
"""
|
45 |
+
Multi-scale deformable attention.
|
46 |
+
https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py
|
47 |
+
"""
|
48 |
+
|
49 |
+
bs, _, num_heads, embed_dims = value.shape
|
50 |
+
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
|
51 |
+
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
|
52 |
+
sampling_grids = 2 * sampling_locations - 1
|
53 |
+
sampling_value_list = []
|
54 |
+
for level, (H_, W_) in enumerate(value_spatial_shapes):
|
55 |
+
# bs, H_*W_, num_heads, embed_dims ->
|
56 |
+
# bs, H_*W_, num_heads*embed_dims ->
|
57 |
+
# bs, num_heads*embed_dims, H_*W_ ->
|
58 |
+
# bs*num_heads, embed_dims, H_, W_
|
59 |
+
value_l_ = (value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_))
|
60 |
+
# bs, num_queries, num_heads, num_points, 2 ->
|
61 |
+
# bs, num_heads, num_queries, num_points, 2 ->
|
62 |
+
# bs*num_heads, num_queries, num_points, 2
|
63 |
+
sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
|
64 |
+
# bs*num_heads, embed_dims, num_queries, num_points
|
65 |
+
sampling_value_l_ = F.grid_sample(value_l_,
|
66 |
+
sampling_grid_l_,
|
67 |
+
mode='bilinear',
|
68 |
+
padding_mode='zeros',
|
69 |
+
align_corners=False)
|
70 |
+
sampling_value_list.append(sampling_value_l_)
|
71 |
+
# (bs, num_queries, num_heads, num_levels, num_points) ->
|
72 |
+
# (bs, num_heads, num_queries, num_levels, num_points) ->
|
73 |
+
# (bs, num_heads, 1, num_queries, num_levels*num_points)
|
74 |
+
attention_weights = attention_weights.transpose(1, 2).reshape(bs * num_heads, 1, num_queries,
|
75 |
+
num_levels * num_points)
|
76 |
+
output = ((torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(
|
77 |
+
bs, num_heads * embed_dims, num_queries))
|
78 |
+
return output.transpose(1, 2).contiguous()
|
ultralytics/nn/tasks.py
ADDED
@@ -0,0 +1,773 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
import contextlib
|
4 |
+
from copy import deepcopy
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
|
11 |
+
Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d,
|
12 |
+
Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv,
|
13 |
+
RTDETRDecoder, Segment)
|
14 |
+
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
|
15 |
+
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
|
16 |
+
from ultralytics.yolo.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
|
17 |
+
from ultralytics.yolo.utils.plotting import feature_visualization
|
18 |
+
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights,
|
19 |
+
intersect_dicts, make_divisible, model_info, scale_img, time_sync)
|
20 |
+
|
21 |
+
try:
|
22 |
+
import thop
|
23 |
+
except ImportError:
|
24 |
+
thop = None
|
25 |
+
|
26 |
+
|
27 |
+
class BaseModel(nn.Module):
|
28 |
+
"""
|
29 |
+
The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family.
|
30 |
+
"""
|
31 |
+
|
32 |
+
def forward(self, x, *args, **kwargs):
|
33 |
+
"""
|
34 |
+
Forward pass of the model on a single scale.
|
35 |
+
Wrapper for `_forward_once` method.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
x (torch.Tensor | dict): The input image tensor or a dict including image tensor and gt labels.
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
(torch.Tensor): The output of the network.
|
42 |
+
"""
|
43 |
+
if isinstance(x, dict): # for cases of training and validating while training.
|
44 |
+
return self.loss(x, *args, **kwargs)
|
45 |
+
return self.predict(x, *args, **kwargs)
|
46 |
+
|
47 |
+
def predict(self, x, profile=False, visualize=False, augment=False):
|
48 |
+
"""
|
49 |
+
Perform a forward pass through the network.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
x (torch.Tensor): The input tensor to the model.
|
53 |
+
profile (bool): Print the computation time of each layer if True, defaults to False.
|
54 |
+
visualize (bool): Save the feature maps of the model if True, defaults to False.
|
55 |
+
augment (bool): Augment image during prediction, defaults to False.
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
(torch.Tensor): The last output of the model.
|
59 |
+
"""
|
60 |
+
if augment:
|
61 |
+
return self._predict_augment(x)
|
62 |
+
return self._predict_once(x, profile, visualize)
|
63 |
+
|
64 |
+
def _predict_once(self, x, profile=False, visualize=False):
|
65 |
+
"""
|
66 |
+
Perform a forward pass through the network.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
x (torch.Tensor): The input tensor to the model.
|
70 |
+
profile (bool): Print the computation time of each layer if True, defaults to False.
|
71 |
+
visualize (bool): Save the feature maps of the model if True, defaults to False.
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
(torch.Tensor): The last output of the model.
|
75 |
+
"""
|
76 |
+
y, dt = [], [] # outputs
|
77 |
+
for m in self.model:
|
78 |
+
if m.f != -1: # if not from previous layer
|
79 |
+
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
|
80 |
+
if profile:
|
81 |
+
self._profile_one_layer(m, x, dt)
|
82 |
+
x = m(x) # run
|
83 |
+
y.append(x if m.i in self.save else None) # save output
|
84 |
+
if visualize:
|
85 |
+
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
86 |
+
return x
|
87 |
+
|
88 |
+
def _predict_augment(self, x):
|
89 |
+
"""Perform augmentations on input image x and return augmented inference."""
|
90 |
+
LOGGER.warning(
|
91 |
+
f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
|
92 |
+
)
|
93 |
+
return self._predict_once(x)
|
94 |
+
|
95 |
+
def _profile_one_layer(self, m, x, dt):
|
96 |
+
"""
|
97 |
+
Profile the computation time and FLOPs of a single layer of the model on a given input.
|
98 |
+
Appends the results to the provided list.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
m (nn.Module): The layer to be profiled.
|
102 |
+
x (torch.Tensor): The input data to the layer.
|
103 |
+
dt (list): A list to store the computation time of the layer.
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
None
|
107 |
+
"""
|
108 |
+
c = m == self.model[-1] # is final layer, copy input as inplace fix
|
109 |
+
o = thop.profile(m, inputs=[x.clone() if c else x], verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
|
110 |
+
t = time_sync()
|
111 |
+
for _ in range(10):
|
112 |
+
m(x.clone() if c else x)
|
113 |
+
dt.append((time_sync() - t) * 100)
|
114 |
+
if m == self.model[0]:
|
115 |
+
LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
|
116 |
+
LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
|
117 |
+
if c:
|
118 |
+
LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
|
119 |
+
|
120 |
+
def fuse(self, verbose=True):
|
121 |
+
"""
|
122 |
+
Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the
|
123 |
+
computation efficiency.
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
(nn.Module): The fused model is returned.
|
127 |
+
"""
|
128 |
+
if not self.is_fused():
|
129 |
+
for m in self.model.modules():
|
130 |
+
if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, 'bn'):
|
131 |
+
if isinstance(m, Conv2):
|
132 |
+
m.fuse_convs()
|
133 |
+
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
134 |
+
delattr(m, 'bn') # remove batchnorm
|
135 |
+
m.forward = m.forward_fuse # update forward
|
136 |
+
if isinstance(m, ConvTranspose) and hasattr(m, 'bn'):
|
137 |
+
m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
|
138 |
+
delattr(m, 'bn') # remove batchnorm
|
139 |
+
m.forward = m.forward_fuse # update forward
|
140 |
+
if isinstance(m, RepConv):
|
141 |
+
m.fuse_convs()
|
142 |
+
m.forward = m.forward_fuse # update forward
|
143 |
+
self.info(verbose=verbose)
|
144 |
+
|
145 |
+
return self
|
146 |
+
|
147 |
+
def is_fused(self, thresh=10):
|
148 |
+
"""
|
149 |
+
Check if the model has less than a certain threshold of BatchNorm layers.
|
150 |
+
|
151 |
+
Args:
|
152 |
+
thresh (int, optional): The threshold number of BatchNorm layers. Default is 10.
|
153 |
+
|
154 |
+
Returns:
|
155 |
+
(bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
|
156 |
+
"""
|
157 |
+
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
|
158 |
+
return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
|
159 |
+
|
160 |
+
def info(self, detailed=False, verbose=True, imgsz=640):
|
161 |
+
"""
|
162 |
+
Prints model information
|
163 |
+
|
164 |
+
Args:
|
165 |
+
verbose (bool): if True, prints out the model information. Defaults to False
|
166 |
+
imgsz (int): the size of the image that the model will be trained on. Defaults to 640
|
167 |
+
"""
|
168 |
+
return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)
|
169 |
+
|
170 |
+
def _apply(self, fn):
|
171 |
+
"""
|
172 |
+
`_apply()` is a function that applies a function to all the tensors in the model that are not
|
173 |
+
parameters or registered buffers
|
174 |
+
|
175 |
+
Args:
|
176 |
+
fn: the function to apply to the model
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
A model that is a Detect() object.
|
180 |
+
"""
|
181 |
+
self = super()._apply(fn)
|
182 |
+
m = self.model[-1] # Detect()
|
183 |
+
if isinstance(m, (Detect, Segment)):
|
184 |
+
m.stride = fn(m.stride)
|
185 |
+
m.anchors = fn(m.anchors)
|
186 |
+
m.strides = fn(m.strides)
|
187 |
+
return self
|
188 |
+
|
189 |
+
def load(self, weights, verbose=True):
|
190 |
+
"""Load the weights into the model.
|
191 |
+
|
192 |
+
Args:
|
193 |
+
weights (dict) or (torch.nn.Module): The pre-trained weights to be loaded.
|
194 |
+
verbose (bool, optional): Whether to log the transfer progress. Defaults to True.
|
195 |
+
"""
|
196 |
+
model = weights['model'] if isinstance(weights, dict) else weights # torchvision models are not dicts
|
197 |
+
csd = model.float().state_dict() # checkpoint state_dict as FP32
|
198 |
+
csd = intersect_dicts(csd, self.state_dict()) # intersect
|
199 |
+
self.load_state_dict(csd, strict=False) # load
|
200 |
+
if verbose:
|
201 |
+
LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')
|
202 |
+
|
203 |
+
def loss(self, batch, preds=None):
|
204 |
+
"""
|
205 |
+
Compute loss
|
206 |
+
|
207 |
+
Args:
|
208 |
+
batch (dict): Batch to compute loss on
|
209 |
+
preds (torch.Tensor | List[torch.Tensor]): Predictions.
|
210 |
+
"""
|
211 |
+
if not hasattr(self, 'criterion'):
|
212 |
+
self.criterion = self.init_criterion()
|
213 |
+
return self.criterion(self.predict(batch['img']) if preds is None else preds, batch)
|
214 |
+
|
215 |
+
def init_criterion(self):
|
216 |
+
raise NotImplementedError('compute_loss() needs to be implemented by task heads')
|
217 |
+
|
218 |
+
|
219 |
+
class DetectionModel(BaseModel):
|
220 |
+
"""YOLOv8 detection model."""
|
221 |
+
|
222 |
+
def __init__(self, cfg='yolov8n.yaml', ch=3, nc=None, verbose=True): # model, input channels, number of classes
|
223 |
+
super().__init__()
|
224 |
+
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
|
225 |
+
|
226 |
+
# Define model
|
227 |
+
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
|
228 |
+
if nc and nc != self.yaml['nc']:
|
229 |
+
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
|
230 |
+
self.yaml['nc'] = nc # override yaml value
|
231 |
+
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
|
232 |
+
self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
|
233 |
+
self.inplace = self.yaml.get('inplace', True)
|
234 |
+
|
235 |
+
# Build strides
|
236 |
+
m = self.model[-1] # Detect()
|
237 |
+
if isinstance(m, (Detect, Segment, Pose)):
|
238 |
+
s = 256 # 2x min stride
|
239 |
+
m.inplace = self.inplace
|
240 |
+
forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose)) else self.forward(x)
|
241 |
+
m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
|
242 |
+
self.stride = m.stride
|
243 |
+
m.bias_init() # only run once
|
244 |
+
|
245 |
+
# Init weights, biases
|
246 |
+
initialize_weights(self)
|
247 |
+
if verbose:
|
248 |
+
self.info()
|
249 |
+
LOGGER.info('')
|
250 |
+
|
251 |
+
def _predict_augment(self, x):
|
252 |
+
"""Perform augmentations on input image x and return augmented inference and train outputs."""
|
253 |
+
img_size = x.shape[-2:] # height, width
|
254 |
+
s = [1, 0.83, 0.67] # scales
|
255 |
+
f = [None, 3, None] # flips (2-ud, 3-lr)
|
256 |
+
y = [] # outputs
|
257 |
+
for si, fi in zip(s, f):
|
258 |
+
xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
|
259 |
+
yi = super().predict(xi)[0] # forward
|
260 |
+
# cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
|
261 |
+
yi = self._descale_pred(yi, fi, si, img_size)
|
262 |
+
y.append(yi)
|
263 |
+
y = self._clip_augmented(y) # clip augmented tails
|
264 |
+
return torch.cat(y, -1), None # augmented inference, train
|
265 |
+
|
266 |
+
@staticmethod
|
267 |
+
def _descale_pred(p, flips, scale, img_size, dim=1):
|
268 |
+
"""De-scale predictions following augmented inference (inverse operation)."""
|
269 |
+
p[:, :4] /= scale # de-scale
|
270 |
+
x, y, wh, cls = p.split((1, 1, 2, p.shape[dim] - 4), dim)
|
271 |
+
if flips == 2:
|
272 |
+
y = img_size[0] - y # de-flip ud
|
273 |
+
elif flips == 3:
|
274 |
+
x = img_size[1] - x # de-flip lr
|
275 |
+
return torch.cat((x, y, wh, cls), dim)
|
276 |
+
|
277 |
+
def _clip_augmented(self, y):
|
278 |
+
"""Clip YOLOv5 augmented inference tails."""
|
279 |
+
nl = self.model[-1].nl # number of detection layers (P3-P5)
|
280 |
+
g = sum(4 ** x for x in range(nl)) # grid points
|
281 |
+
e = 1 # exclude layer count
|
282 |
+
i = (y[0].shape[-1] // g) * sum(4 ** x for x in range(e)) # indices
|
283 |
+
y[0] = y[0][..., :-i] # large
|
284 |
+
i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
|
285 |
+
y[-1] = y[-1][..., i:] # small
|
286 |
+
return y
|
287 |
+
|
288 |
+
def init_criterion(self):
|
289 |
+
return v8DetectionLoss(self)
|
290 |
+
|
291 |
+
|
292 |
+
class SegmentationModel(DetectionModel):
|
293 |
+
"""YOLOv8 segmentation model."""
|
294 |
+
|
295 |
+
def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True):
|
296 |
+
"""Initialize YOLOv8 segmentation model with given config and parameters."""
|
297 |
+
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
298 |
+
|
299 |
+
def init_criterion(self):
|
300 |
+
return v8SegmentationLoss(self)
|
301 |
+
|
302 |
+
def _predict_augment(self, x):
|
303 |
+
"""Perform augmentations on input image x and return augmented inference."""
|
304 |
+
LOGGER.warning(
|
305 |
+
f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
|
306 |
+
)
|
307 |
+
return self._predict_once(x)
|
308 |
+
|
309 |
+
|
310 |
+
class PoseModel(DetectionModel):
|
311 |
+
"""YOLOv8 pose model."""
|
312 |
+
|
313 |
+
def __init__(self, cfg='yolov8n-pose.yaml', ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
|
314 |
+
"""Initialize YOLOv8 Pose model."""
|
315 |
+
if not isinstance(cfg, dict):
|
316 |
+
cfg = yaml_model_load(cfg) # load model YAML
|
317 |
+
if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg['kpt_shape']):
|
318 |
+
LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")
|
319 |
+
cfg['kpt_shape'] = data_kpt_shape
|
320 |
+
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
321 |
+
|
322 |
+
def init_criterion(self):
|
323 |
+
return v8PoseLoss(self)
|
324 |
+
|
325 |
+
def _predict_augment(self, x):
|
326 |
+
"""Perform augmentations on input image x and return augmented inference."""
|
327 |
+
LOGGER.warning(
|
328 |
+
f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
|
329 |
+
)
|
330 |
+
return self._predict_once(x)
|
331 |
+
|
332 |
+
|
333 |
+
class ClassificationModel(BaseModel):
|
334 |
+
"""YOLOv8 classification model."""
|
335 |
+
|
336 |
+
def __init__(self,
|
337 |
+
cfg=None,
|
338 |
+
model=None,
|
339 |
+
ch=3,
|
340 |
+
nc=None,
|
341 |
+
cutoff=10,
|
342 |
+
verbose=True): # yaml, model, channels, number of classes, cutoff index, verbose flag
|
343 |
+
super().__init__()
|
344 |
+
self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose)
|
345 |
+
|
346 |
+
def _from_detection_model(self, model, nc=1000, cutoff=10):
|
347 |
+
"""Create a YOLOv5 classification model from a YOLOv5 detection model."""
|
348 |
+
from ultralytics.nn.autobackend import AutoBackend
|
349 |
+
if isinstance(model, AutoBackend):
|
350 |
+
model = model.model # unwrap DetectMultiBackend
|
351 |
+
model.model = model.model[:cutoff] # backbone
|
352 |
+
m = model.model[-1] # last layer
|
353 |
+
ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels # ch into module
|
354 |
+
c = Classify(ch, nc) # Classify()
|
355 |
+
c.i, c.f, c.type = m.i, m.f, 'models.common.Classify' # index, from, type
|
356 |
+
model.model[-1] = c # replace
|
357 |
+
self.model = model.model
|
358 |
+
self.stride = model.stride
|
359 |
+
self.save = []
|
360 |
+
self.nc = nc
|
361 |
+
|
362 |
+
def _from_yaml(self, cfg, ch, nc, verbose):
|
363 |
+
"""Set YOLOv8 model configurations and define the model architecture."""
|
364 |
+
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
|
365 |
+
|
366 |
+
# Define model
|
367 |
+
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
|
368 |
+
if nc and nc != self.yaml['nc']:
|
369 |
+
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
|
370 |
+
self.yaml['nc'] = nc # override yaml value
|
371 |
+
elif not nc and not self.yaml.get('nc', None):
|
372 |
+
raise ValueError('nc not specified. Must specify nc in model.yaml or function arguments.')
|
373 |
+
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
|
374 |
+
self.stride = torch.Tensor([1]) # no stride constraints
|
375 |
+
self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
|
376 |
+
self.info()
|
377 |
+
|
378 |
+
@staticmethod
|
379 |
+
def reshape_outputs(model, nc):
|
380 |
+
"""Update a TorchVision classification model to class count 'n' if required."""
|
381 |
+
name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1] # last module
|
382 |
+
if isinstance(m, Classify): # YOLO Classify() head
|
383 |
+
if m.linear.out_features != nc:
|
384 |
+
m.linear = nn.Linear(m.linear.in_features, nc)
|
385 |
+
elif isinstance(m, nn.Linear): # ResNet, EfficientNet
|
386 |
+
if m.out_features != nc:
|
387 |
+
setattr(model, name, nn.Linear(m.in_features, nc))
|
388 |
+
elif isinstance(m, nn.Sequential):
|
389 |
+
types = [type(x) for x in m]
|
390 |
+
if nn.Linear in types:
|
391 |
+
i = types.index(nn.Linear) # nn.Linear index
|
392 |
+
if m[i].out_features != nc:
|
393 |
+
m[i] = nn.Linear(m[i].in_features, nc)
|
394 |
+
elif nn.Conv2d in types:
|
395 |
+
i = types.index(nn.Conv2d) # nn.Conv2d index
|
396 |
+
if m[i].out_channels != nc:
|
397 |
+
m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
|
398 |
+
|
399 |
+
def init_criterion(self):
|
400 |
+
"""Compute the classification loss between predictions and true labels."""
|
401 |
+
return v8ClassificationLoss()
|
402 |
+
|
403 |
+
|
404 |
+
class RTDETRDetectionModel(DetectionModel):
|
405 |
+
|
406 |
+
def __init__(self, cfg='rtdetr-l.yaml', ch=3, nc=None, verbose=True):
|
407 |
+
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
408 |
+
|
409 |
+
def init_criterion(self):
|
410 |
+
"""Compute the classification loss between predictions and true labels."""
|
411 |
+
from ultralytics.vit.utils.loss import RTDETRDetectionLoss
|
412 |
+
|
413 |
+
return RTDETRDetectionLoss(num_classes=self.nc, use_vfl=True)
|
414 |
+
|
415 |
+
def loss(self, batch, preds=None):
|
416 |
+
if not hasattr(self, 'criterion'):
|
417 |
+
self.criterion = self.init_criterion()
|
418 |
+
|
419 |
+
img = batch['img']
|
420 |
+
# NOTE: preprocess gt_bbox and gt_labels to list.
|
421 |
+
bs = len(img)
|
422 |
+
batch_idx = batch['batch_idx']
|
423 |
+
gt_bbox, gt_class = [], []
|
424 |
+
for i in range(bs):
|
425 |
+
gt_bbox.append(batch['bboxes'][batch_idx == i].to(img.device))
|
426 |
+
gt_class.append(batch['cls'][batch_idx == i].to(device=img.device, dtype=torch.long))
|
427 |
+
targets = {'cls': gt_class, 'bboxes': gt_bbox}
|
428 |
+
|
429 |
+
preds = self.predict(img, batch=targets) if preds is None else preds
|
430 |
+
dec_out_bboxes, dec_out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta = preds
|
431 |
+
# NOTE: `dn_meta` means it's eval mode, loss calculation for eval mode is not supported.
|
432 |
+
if dn_meta is None:
|
433 |
+
return 0, torch.zeros(3, device=dec_out_bboxes.device)
|
434 |
+
dn_out_bboxes, dec_out_bboxes = torch.split(dec_out_bboxes, dn_meta['dn_num_split'], dim=2)
|
435 |
+
dn_out_logits, dec_out_logits = torch.split(dec_out_logits, dn_meta['dn_num_split'], dim=2)
|
436 |
+
|
437 |
+
out_bboxes = torch.cat([enc_topk_bboxes.unsqueeze(0), dec_out_bboxes])
|
438 |
+
out_logits = torch.cat([enc_topk_logits.unsqueeze(0), dec_out_logits])
|
439 |
+
|
440 |
+
loss = self.criterion((out_bboxes, out_logits),
|
441 |
+
targets,
|
442 |
+
dn_out_bboxes=dn_out_bboxes,
|
443 |
+
dn_out_logits=dn_out_logits,
|
444 |
+
dn_meta=dn_meta)
|
445 |
+
return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']])
|
446 |
+
|
447 |
+
def predict(self, x, profile=False, visualize=False, batch=None):
|
448 |
+
"""
|
449 |
+
Perform a forward pass through the network.
|
450 |
+
|
451 |
+
Args:
|
452 |
+
x (torch.Tensor): The input tensor to the model
|
453 |
+
profile (bool): Print the computation time of each layer if True, defaults to False.
|
454 |
+
visualize (bool): Save the feature maps of the model if True, defaults to False
|
455 |
+
batch (dict): A dict including gt boxes and labels from dataloader.
|
456 |
+
|
457 |
+
Returns:
|
458 |
+
(torch.Tensor): The last output of the model.
|
459 |
+
"""
|
460 |
+
y, dt = [], [] # outputs
|
461 |
+
for m in self.model[:-1]: # except the head part
|
462 |
+
if m.f != -1: # if not from previous layer
|
463 |
+
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
|
464 |
+
if profile:
|
465 |
+
self._profile_one_layer(m, x, dt)
|
466 |
+
x = m(x) # run
|
467 |
+
y.append(x if m.i in self.save else None) # save output
|
468 |
+
if visualize:
|
469 |
+
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
470 |
+
head = self.model[-1]
|
471 |
+
x = head([y[j] for j in head.f], batch) # head inference
|
472 |
+
return x
|
473 |
+
|
474 |
+
|
475 |
+
class Ensemble(nn.ModuleList):
|
476 |
+
"""Ensemble of models."""
|
477 |
+
|
478 |
+
def __init__(self):
|
479 |
+
"""Initialize an ensemble of models."""
|
480 |
+
super().__init__()
|
481 |
+
|
482 |
+
def forward(self, x, augment=False, profile=False, visualize=False):
|
483 |
+
"""Function generates the YOLOv5 network's final layer."""
|
484 |
+
y = [module(x, augment, profile, visualize)[0] for module in self]
|
485 |
+
# y = torch.stack(y).max(0)[0] # max ensemble
|
486 |
+
# y = torch.stack(y).mean(0) # mean ensemble
|
487 |
+
y = torch.cat(y, 2) # nms ensemble, y shape(B, HW, C)
|
488 |
+
return y, None # inference, train output
|
489 |
+
|
490 |
+
|
491 |
+
# Functions ------------------------------------------------------------------------------------------------------------
|
492 |
+
|
493 |
+
|
494 |
+
def torch_safe_load(weight):
|
495 |
+
"""
|
496 |
+
This function attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised,
|
497 |
+
it catches the error, logs a warning message, and attempts to install the missing module via the
|
498 |
+
check_requirements() function. After installation, the function again attempts to load the model using torch.load().
|
499 |
+
|
500 |
+
Args:
|
501 |
+
weight (str): The file path of the PyTorch model.
|
502 |
+
|
503 |
+
Returns:
|
504 |
+
(dict): The loaded PyTorch model.
|
505 |
+
"""
|
506 |
+
from ultralytics.yolo.utils.downloads import attempt_download_asset
|
507 |
+
|
508 |
+
check_suffix(file=weight, suffix='.pt')
|
509 |
+
file = attempt_download_asset(weight) # search online if missing locally
|
510 |
+
try:
|
511 |
+
return torch.load(file, map_location='cpu'), file # load
|
512 |
+
except ModuleNotFoundError as e: # e.name is missing module name
|
513 |
+
if e.name == 'models':
|
514 |
+
raise TypeError(
|
515 |
+
emojis(f'ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained '
|
516 |
+
f'with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with '
|
517 |
+
f'YOLOv8 at https://github.com/ultralytics/ultralytics.'
|
518 |
+
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
|
519 |
+
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")) from e
|
520 |
+
LOGGER.warning(f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
|
521 |
+
f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
|
522 |
+
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
|
523 |
+
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")
|
524 |
+
check_requirements(e.name) # install missing module
|
525 |
+
|
526 |
+
return torch.load(file, map_location='cpu'), file # load
|
527 |
+
|
528 |
+
|
529 |
+
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
530 |
+
"""Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a."""
|
531 |
+
|
532 |
+
ensemble = Ensemble()
|
533 |
+
for w in weights if isinstance(weights, list) else [weights]:
|
534 |
+
ckpt, w = torch_safe_load(w) # load ckpt
|
535 |
+
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} if 'train_args' in ckpt else None # combined args
|
536 |
+
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
537 |
+
|
538 |
+
# Model compatibility updates
|
539 |
+
model.args = args # attach args to model
|
540 |
+
model.pt_path = w # attach *.pt file path to model
|
541 |
+
model.task = guess_model_task(model)
|
542 |
+
if not hasattr(model, 'stride'):
|
543 |
+
model.stride = torch.tensor([32.])
|
544 |
+
|
545 |
+
# Append
|
546 |
+
ensemble.append(model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval()) # model in eval mode
|
547 |
+
|
548 |
+
# Module compatibility updates
|
549 |
+
for m in ensemble.modules():
|
550 |
+
t = type(m)
|
551 |
+
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
|
552 |
+
m.inplace = inplace # torch 1.7.0 compatibility
|
553 |
+
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
|
554 |
+
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
555 |
+
|
556 |
+
# Return model
|
557 |
+
if len(ensemble) == 1:
|
558 |
+
return ensemble[-1]
|
559 |
+
|
560 |
+
# Return ensemble
|
561 |
+
LOGGER.info(f'Ensemble created with {weights}\n')
|
562 |
+
for k in 'names', 'nc', 'yaml':
|
563 |
+
setattr(ensemble, k, getattr(ensemble[0], k))
|
564 |
+
ensemble.stride = ensemble[torch.argmax(torch.tensor([m.stride.max() for m in ensemble])).int()].stride
|
565 |
+
assert all(ensemble[0].nc == m.nc for m in ensemble), f'Models differ in class counts {[m.nc for m in ensemble]}'
|
566 |
+
return ensemble
|
567 |
+
|
568 |
+
|
569 |
+
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
570 |
+
"""Loads a single model weights."""
|
571 |
+
ckpt, weight = torch_safe_load(weight) # load ckpt
|
572 |
+
args = {**DEFAULT_CFG_DICT, **(ckpt.get('train_args', {}))} # combine model and default args, preferring model args
|
573 |
+
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
574 |
+
|
575 |
+
# Model compatibility updates
|
576 |
+
model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
|
577 |
+
model.pt_path = weight # attach *.pt file path to model
|
578 |
+
model.task = guess_model_task(model)
|
579 |
+
if not hasattr(model, 'stride'):
|
580 |
+
model.stride = torch.tensor([32.])
|
581 |
+
|
582 |
+
model = model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval() # model in eval mode
|
583 |
+
|
584 |
+
# Module compatibility updates
|
585 |
+
for m in model.modules():
|
586 |
+
t = type(m)
|
587 |
+
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
|
588 |
+
m.inplace = inplace # torch 1.7.0 compatibility
|
589 |
+
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
|
590 |
+
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
591 |
+
|
592 |
+
# Return model and ckpt
|
593 |
+
return model, ckpt
|
594 |
+
|
595 |
+
|
596 |
+
def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
597 |
+
# Parse a YOLO model.yaml dictionary into a PyTorch model
|
598 |
+
import ast
|
599 |
+
|
600 |
+
# Args
|
601 |
+
max_channels = float('inf')
|
602 |
+
nc, act, scales = (d.get(x) for x in ('nc', 'act', 'scales'))
|
603 |
+
depth, width, kpt_shape = (d.get(x, 1.0) for x in ('depth_multiple', 'width_multiple', 'kpt_shape'))
|
604 |
+
if scales:
|
605 |
+
scale = d.get('scale')
|
606 |
+
if not scale:
|
607 |
+
scale = tuple(scales.keys())[0]
|
608 |
+
LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
|
609 |
+
depth, width, max_channels = scales[scale]
|
610 |
+
|
611 |
+
if act:
|
612 |
+
Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
|
613 |
+
if verbose:
|
614 |
+
LOGGER.info(f"{colorstr('activation:')} {act}") # print
|
615 |
+
|
616 |
+
if verbose:
|
617 |
+
LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
|
618 |
+
ch = [ch]
|
619 |
+
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
|
620 |
+
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
|
621 |
+
m = getattr(torch.nn, m[3:]) if 'nn.' in m else globals()[m] # get module
|
622 |
+
for j, a in enumerate(args):
|
623 |
+
if isinstance(a, str):
|
624 |
+
with contextlib.suppress(ValueError):
|
625 |
+
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
|
626 |
+
|
627 |
+
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
|
628 |
+
if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
|
629 |
+
BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3):
|
630 |
+
c1, c2 = ch[f], args[0]
|
631 |
+
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
|
632 |
+
c2 = make_divisible(min(c2, max_channels) * width, 8)
|
633 |
+
|
634 |
+
args = [c1, c2, *args[1:]]
|
635 |
+
if m in (BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, C3x, RepC3):
|
636 |
+
args.insert(2, n) # number of repeats
|
637 |
+
n = 1
|
638 |
+
elif m is AIFI:
|
639 |
+
args = [ch[f], *args]
|
640 |
+
elif m in (HGStem, HGBlock):
|
641 |
+
c1, cm, c2 = ch[f], args[0], args[1]
|
642 |
+
args = [c1, cm, c2, *args[2:]]
|
643 |
+
if m is HGBlock:
|
644 |
+
args.insert(4, n) # number of repeats
|
645 |
+
n = 1
|
646 |
+
|
647 |
+
elif m is nn.BatchNorm2d:
|
648 |
+
args = [ch[f]]
|
649 |
+
elif m is Concat:
|
650 |
+
c2 = sum(ch[x] for x in f)
|
651 |
+
elif m in (Detect, Segment, Pose, RTDETRDecoder):
|
652 |
+
args.append([ch[x] for x in f])
|
653 |
+
if m is Segment:
|
654 |
+
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
|
655 |
+
else:
|
656 |
+
c2 = ch[f]
|
657 |
+
|
658 |
+
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
|
659 |
+
t = str(m)[8:-2].replace('__main__.', '') # module type
|
660 |
+
m.np = sum(x.numel() for x in m_.parameters()) # number params
|
661 |
+
m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
|
662 |
+
if verbose:
|
663 |
+
LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}') # print
|
664 |
+
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
|
665 |
+
layers.append(m_)
|
666 |
+
if i == 0:
|
667 |
+
ch = []
|
668 |
+
ch.append(c2)
|
669 |
+
return nn.Sequential(*layers), sorted(save)
|
670 |
+
|
671 |
+
|
672 |
+
def yaml_model_load(path):
|
673 |
+
"""Load a YOLOv8 model from a YAML file."""
|
674 |
+
import re
|
675 |
+
|
676 |
+
path = Path(path)
|
677 |
+
if path.stem in (f'yolov{d}{x}6' for x in 'nsmlx' for d in (5, 8)):
|
678 |
+
new_stem = re.sub(r'(\d+)([nslmx])6(.+)?$', r'\1\2-p6\3', path.stem)
|
679 |
+
LOGGER.warning(f'WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.')
|
680 |
+
path = path.with_stem(new_stem)
|
681 |
+
|
682 |
+
unified_path = re.sub(r'(\d+)([nslmx])(.+)?$', r'\1\3', str(path)) # i.e. yolov8x.yaml -> yolov8.yaml
|
683 |
+
yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
|
684 |
+
d = yaml_load(yaml_file) # model dict
|
685 |
+
d['scale'] = guess_model_scale(path)
|
686 |
+
d['yaml_file'] = str(path)
|
687 |
+
return d
|
688 |
+
|
689 |
+
|
690 |
+
def guess_model_scale(model_path):
|
691 |
+
"""
|
692 |
+
Takes a path to a YOLO model's YAML file as input and extracts the size character of the model's scale.
|
693 |
+
The function uses regular expression matching to find the pattern of the model scale in the YAML file name,
|
694 |
+
which is denoted by n, s, m, l, or x. The function returns the size character of the model scale as a string.
|
695 |
+
|
696 |
+
Args:
|
697 |
+
model_path (str) or (Path): The path to the YOLO model's YAML file.
|
698 |
+
|
699 |
+
Returns:
|
700 |
+
(str): The size character of the model's scale, which can be n, s, m, l, or x.
|
701 |
+
"""
|
702 |
+
with contextlib.suppress(AttributeError):
|
703 |
+
import re
|
704 |
+
return re.search(r'yolov\d+([nslmx])', Path(model_path).stem).group(1) # n, s, m, l, or x
|
705 |
+
return ''
|
706 |
+
|
707 |
+
|
708 |
+
def guess_model_task(model):
|
709 |
+
"""
|
710 |
+
Guess the task of a PyTorch model from its architecture or configuration.
|
711 |
+
|
712 |
+
Args:
|
713 |
+
model (nn.Module) or (dict): PyTorch model or model configuration in YAML format.
|
714 |
+
|
715 |
+
Returns:
|
716 |
+
(str): Task of the model ('detect', 'segment', 'classify', 'pose').
|
717 |
+
|
718 |
+
Raises:
|
719 |
+
SyntaxError: If the task of the model could not be determined.
|
720 |
+
"""
|
721 |
+
|
722 |
+
def cfg2task(cfg):
|
723 |
+
"""Guess from YAML dictionary."""
|
724 |
+
m = cfg['head'][-1][-2].lower() # output module name
|
725 |
+
if m in ('classify', 'classifier', 'cls', 'fc'):
|
726 |
+
return 'classify'
|
727 |
+
if m == 'detect':
|
728 |
+
return 'detect'
|
729 |
+
if m == 'segment':
|
730 |
+
return 'segment'
|
731 |
+
if m == 'pose':
|
732 |
+
return 'pose'
|
733 |
+
|
734 |
+
# Guess from model cfg
|
735 |
+
if isinstance(model, dict):
|
736 |
+
with contextlib.suppress(Exception):
|
737 |
+
return cfg2task(model)
|
738 |
+
|
739 |
+
# Guess from PyTorch model
|
740 |
+
if isinstance(model, nn.Module): # PyTorch model
|
741 |
+
for x in 'model.args', 'model.model.args', 'model.model.model.args':
|
742 |
+
with contextlib.suppress(Exception):
|
743 |
+
return eval(x)['task']
|
744 |
+
for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml':
|
745 |
+
with contextlib.suppress(Exception):
|
746 |
+
return cfg2task(eval(x))
|
747 |
+
|
748 |
+
for m in model.modules():
|
749 |
+
if isinstance(m, Detect):
|
750 |
+
return 'detect'
|
751 |
+
elif isinstance(m, Segment):
|
752 |
+
return 'segment'
|
753 |
+
elif isinstance(m, Classify):
|
754 |
+
return 'classify'
|
755 |
+
elif isinstance(m, Pose):
|
756 |
+
return 'pose'
|
757 |
+
|
758 |
+
# Guess from model filename
|
759 |
+
if isinstance(model, (str, Path)):
|
760 |
+
model = Path(model)
|
761 |
+
if '-seg' in model.stem or 'segment' in model.parts:
|
762 |
+
return 'segment'
|
763 |
+
elif '-cls' in model.stem or 'classify' in model.parts:
|
764 |
+
return 'classify'
|
765 |
+
elif '-pose' in model.stem or 'pose' in model.parts:
|
766 |
+
return 'pose'
|
767 |
+
elif 'detect' in model.parts:
|
768 |
+
return 'detect'
|
769 |
+
|
770 |
+
# Unable to determine task from model
|
771 |
+
LOGGER.warning("WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
|
772 |
+
"Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify', or 'pose'.")
|
773 |
+
return 'detect' # assume detect
|
ultralytics/tracker/README.md
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Tracker
|
2 |
+
|
3 |
+
## Supported Trackers
|
4 |
+
|
5 |
+
- [x] ByteTracker
|
6 |
+
- [x] BoT-SORT
|
7 |
+
|
8 |
+
## Usage
|
9 |
+
|
10 |
+
### python interface:
|
11 |
+
|
12 |
+
You can use the Python interface to track objects using the YOLO model.
|
13 |
+
|
14 |
+
```python
|
15 |
+
from ultralytics import YOLO
|
16 |
+
|
17 |
+
model = YOLO("yolov8n.pt") # or a segmentation model .i.e yolov8n-seg.pt
|
18 |
+
model.track(
|
19 |
+
source="video/streams",
|
20 |
+
stream=True,
|
21 |
+
tracker="botsort.yaml", # or 'bytetrack.yaml'
|
22 |
+
show=True,
|
23 |
+
)
|
24 |
+
```
|
25 |
+
|
26 |
+
You can get the IDs of the tracked objects using the following code:
|
27 |
+
|
28 |
+
```python
|
29 |
+
from ultralytics import YOLO
|
30 |
+
|
31 |
+
model = YOLO("yolov8n.pt")
|
32 |
+
|
33 |
+
for result in model.track(source="video.mp4"):
|
34 |
+
print(
|
35 |
+
result.boxes.id.cpu().numpy().astype(int)
|
36 |
+
) # this will print the IDs of the tracked objects in the frame
|
37 |
+
```
|
38 |
+
|
39 |
+
If you want to use the tracker with a folder of images or when you loop on the video frames, you should use the `persist` parameter to tell the model that these frames are related to each other so the IDs will be fixed for the same objects. Otherwise, the IDs will be different in each frame because in each loop, the model creates a new object for tracking, but the `persist` parameter makes it use the same object for tracking.
|
40 |
+
|
41 |
+
```python
|
42 |
+
import cv2
|
43 |
+
from ultralytics import YOLO
|
44 |
+
|
45 |
+
cap = cv2.VideoCapture("video.mp4")
|
46 |
+
model = YOLO("yolov8n.pt")
|
47 |
+
while True:
|
48 |
+
ret, frame = cap.read()
|
49 |
+
if not ret:
|
50 |
+
break
|
51 |
+
results = model.track(frame, persist=True)
|
52 |
+
boxes = results[0].boxes.xyxy.cpu().numpy().astype(int)
|
53 |
+
ids = results[0].boxes.id.cpu().numpy().astype(int)
|
54 |
+
for box, id in zip(boxes, ids):
|
55 |
+
cv2.rectangle(frame, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
|
56 |
+
cv2.putText(
|
57 |
+
frame,
|
58 |
+
f"Id {id}",
|
59 |
+
(box[0], box[1]),
|
60 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
61 |
+
1,
|
62 |
+
(0, 0, 255),
|
63 |
+
2,
|
64 |
+
)
|
65 |
+
cv2.imshow("frame", frame)
|
66 |
+
if cv2.waitKey(1) & 0xFF == ord("q"):
|
67 |
+
break
|
68 |
+
```
|
69 |
+
|
70 |
+
## Change tracker parameters
|
71 |
+
|
72 |
+
You can change the tracker parameters by eding the `tracker.yaml` file which is located in the ultralytics/tracker/cfg folder.
|
73 |
+
|
74 |
+
## Command Line Interface (CLI)
|
75 |
+
|
76 |
+
You can also use the command line interface to track objects using the YOLO model.
|
77 |
+
|
78 |
+
```bash
|
79 |
+
yolo detect track source=... tracker=...
|
80 |
+
yolo segment track source=... tracker=...
|
81 |
+
yolo pose track source=... tracker=...
|
82 |
+
```
|
83 |
+
|
84 |
+
By default, trackers will use the configuration in `ultralytics/tracker/cfg`.
|
85 |
+
We also support using a modified tracker config file. Please refer to the tracker config files
|
86 |
+
in `ultralytics/tracker/cfg`.<br>
|
ultralytics/tracker/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
from .track import register_tracker
|
4 |
+
from .trackers import BOTSORT, BYTETracker
|
5 |
+
|
6 |
+
__all__ = 'register_tracker', 'BOTSORT', 'BYTETracker' # allow simpler import
|