glenn-jocher
commited on
Commit
•
33202b7
1
Parent(s):
6a3ee7c
YOLOv5 + Albumentations integration (#3882)
Browse files* Albumentations integration
* ToGray p=0.01
* print confirmation
* create instance in dataloader init method
* improved version handling
* transform not defined fix
* assert string update
* create check_version()
* add spaces
* update class comment
- requirements.txt +1 -0
- utils/augmentations.py +29 -1
- utils/datasets.py +21 -19
- utils/general.py +10 -7
requirements.txt
CHANGED
@@ -27,4 +27,5 @@ pandas
|
|
27 |
# extras --------------------------------------
|
28 |
# Cython # for pycocotools https://github.com/cocodataset/cocoapi/issues/172
|
29 |
# pycocotools>=2.0 # COCO mAP
|
|
|
30 |
thop # FLOPs computation
|
|
|
27 |
# extras --------------------------------------
|
28 |
# Cython # for pycocotools https://github.com/cocodataset/cocoapi/issues/172
|
29 |
# pycocotools>=2.0 # COCO mAP
|
30 |
+
# albumentations>=1.0.0
|
31 |
thop # FLOPs computation
|
utils/augmentations.py
CHANGED
@@ -1,15 +1,43 @@
|
|
1 |
# YOLOv5 image augmentation functions
|
2 |
|
|
|
3 |
import random
|
4 |
|
5 |
import cv2
|
6 |
import math
|
7 |
import numpy as np
|
8 |
|
9 |
-
from utils.general import segment2box, resample_segments
|
10 |
from utils.metrics import bbox_ioa
|
11 |
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
|
14 |
# HSV color-space augmentation
|
15 |
if hgain or sgain or vgain:
|
|
|
1 |
# YOLOv5 image augmentation functions
|
2 |
|
3 |
+
import logging
|
4 |
import random
|
5 |
|
6 |
import cv2
|
7 |
import math
|
8 |
import numpy as np
|
9 |
|
10 |
+
from utils.general import colorstr, segment2box, resample_segments, check_version
|
11 |
from utils.metrics import bbox_ioa
|
12 |
|
13 |
|
14 |
+
class Albumentations:
|
15 |
+
# YOLOv5 Albumentations class (optional, only used if package is installed)
|
16 |
+
def __init__(self):
|
17 |
+
self.transform = None
|
18 |
+
try:
|
19 |
+
import albumentations as A
|
20 |
+
check_version(A.__version__, '1.0.0') # version requirement
|
21 |
+
|
22 |
+
self.transform = A.Compose([
|
23 |
+
A.Blur(p=0.1),
|
24 |
+
A.MedianBlur(p=0.1),
|
25 |
+
A.ToGray(p=0.01)],
|
26 |
+
bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
|
27 |
+
|
28 |
+
logging.info(colorstr('albumentations: ') + ', '.join(f'{x}' for x in self.transform.transforms))
|
29 |
+
except ImportError: # package not installed, skip
|
30 |
+
pass
|
31 |
+
except Exception as e:
|
32 |
+
logging.info(colorstr('albumentations: ') + f'{e}')
|
33 |
+
|
34 |
+
def __call__(self, im, labels, p=1.0):
|
35 |
+
if self.transform and random.random() < p:
|
36 |
+
new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0]) # transformed
|
37 |
+
im, labels = new['image'], np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])])
|
38 |
+
return im, labels
|
39 |
+
|
40 |
+
|
41 |
def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
|
42 |
# HSV color-space augmentation
|
43 |
if hgain or sgain or vgain:
|
utils/datasets.py
CHANGED
@@ -22,7 +22,7 @@ from PIL import Image, ExifTags
|
|
22 |
from torch.utils.data import Dataset
|
23 |
from tqdm import tqdm
|
24 |
|
25 |
-
from utils.augmentations import augment_hsv, copy_paste, letterbox, mixup, random_perspective
|
26 |
from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \
|
27 |
xyn2xy, segments2boxes, clean_str
|
28 |
from utils.torch_utils import torch_distributed_zero_first
|
@@ -372,6 +372,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|
372 |
self.mosaic_border = [-img_size // 2, -img_size // 2]
|
373 |
self.stride = stride
|
374 |
self.path = path
|
|
|
375 |
|
376 |
try:
|
377 |
f = [] # image files
|
@@ -539,9 +540,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|
539 |
if labels.size: # normalized xywh to pixel xyxy format
|
540 |
labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
|
541 |
|
542 |
-
|
543 |
-
# Augment imagespace
|
544 |
-
if not mosaic:
|
545 |
img, labels = random_perspective(img, labels,
|
546 |
degrees=hyp['degrees'],
|
547 |
translate=hyp['translate'],
|
@@ -549,32 +548,35 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|
549 |
shear=hyp['shear'],
|
550 |
perspective=hyp['perspective'])
|
551 |
|
552 |
-
|
553 |
-
|
554 |
-
|
555 |
-
# Apply cutouts
|
556 |
-
# if random.random() < 0.9:
|
557 |
-
# labels = cutout(img, labels)
|
558 |
-
|
559 |
-
nL = len(labels) # number of labels
|
560 |
-
if nL:
|
561 |
labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0]) # xyxy to xywh normalized
|
562 |
|
563 |
if self.augment:
|
564 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
565 |
if random.random() < hyp['flipud']:
|
566 |
img = np.flipud(img)
|
567 |
-
if
|
568 |
labels[:, 2] = 1 - labels[:, 2]
|
569 |
|
570 |
-
#
|
571 |
if random.random() < hyp['fliplr']:
|
572 |
img = np.fliplr(img)
|
573 |
-
if
|
574 |
labels[:, 1] = 1 - labels[:, 1]
|
575 |
|
576 |
-
|
577 |
-
|
|
|
|
|
|
|
|
|
578 |
labels_out[:, 1:] = torch.from_numpy(labels)
|
579 |
|
580 |
# Convert
|
|
|
22 |
from torch.utils.data import Dataset
|
23 |
from tqdm import tqdm
|
24 |
|
25 |
+
from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
|
26 |
from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \
|
27 |
xyn2xy, segments2boxes, clean_str
|
28 |
from utils.torch_utils import torch_distributed_zero_first
|
|
|
372 |
self.mosaic_border = [-img_size // 2, -img_size // 2]
|
373 |
self.stride = stride
|
374 |
self.path = path
|
375 |
+
self.albumentations = Albumentations() if augment else None
|
376 |
|
377 |
try:
|
378 |
f = [] # image files
|
|
|
540 |
if labels.size: # normalized xywh to pixel xyxy format
|
541 |
labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
|
542 |
|
543 |
+
if self.augment:
|
|
|
|
|
544 |
img, labels = random_perspective(img, labels,
|
545 |
degrees=hyp['degrees'],
|
546 |
translate=hyp['translate'],
|
|
|
548 |
shear=hyp['shear'],
|
549 |
perspective=hyp['perspective'])
|
550 |
|
551 |
+
nl = len(labels) # number of labels
|
552 |
+
if nl:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
553 |
labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0]) # xyxy to xywh normalized
|
554 |
|
555 |
if self.augment:
|
556 |
+
# Albumentations
|
557 |
+
img, labels = self.albumentations(img, labels)
|
558 |
+
|
559 |
+
# HSV color-space
|
560 |
+
augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
|
561 |
+
|
562 |
+
# Flip up-down
|
563 |
if random.random() < hyp['flipud']:
|
564 |
img = np.flipud(img)
|
565 |
+
if nl:
|
566 |
labels[:, 2] = 1 - labels[:, 2]
|
567 |
|
568 |
+
# Flip left-right
|
569 |
if random.random() < hyp['fliplr']:
|
570 |
img = np.fliplr(img)
|
571 |
+
if nl:
|
572 |
labels[:, 1] = 1 - labels[:, 1]
|
573 |
|
574 |
+
# Cutouts
|
575 |
+
# if random.random() < 0.9:
|
576 |
+
# labels = cutout(img, labels)
|
577 |
+
|
578 |
+
labels_out = torch.zeros((nl, 6))
|
579 |
+
if nl:
|
580 |
labels_out[:, 1:] = torch.from_numpy(labels)
|
581 |
|
582 |
# Convert
|
utils/general.py
CHANGED
@@ -3,7 +3,6 @@
|
|
3 |
import contextlib
|
4 |
import glob
|
5 |
import logging
|
6 |
-
import math
|
7 |
import os
|
8 |
import platform
|
9 |
import random
|
@@ -17,6 +16,7 @@ from pathlib import Path
|
|
17 |
from subprocess import check_output
|
18 |
|
19 |
import cv2
|
|
|
20 |
import numpy as np
|
21 |
import pandas as pd
|
22 |
import pkg_resources as pkg
|
@@ -136,13 +136,16 @@ def check_git_status(err_msg=', for updates see https://github.com/ultralytics/y
|
|
136 |
print(f'{e}{err_msg}')
|
137 |
|
138 |
|
139 |
-
def check_python(minimum='3.6.2'
|
140 |
# Check current python version vs. required python version
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
|
|
|
|
|
|
146 |
|
147 |
|
148 |
def check_requirements(requirements='requirements.txt', exclude=()):
|
|
|
3 |
import contextlib
|
4 |
import glob
|
5 |
import logging
|
|
|
6 |
import os
|
7 |
import platform
|
8 |
import random
|
|
|
16 |
from subprocess import check_output
|
17 |
|
18 |
import cv2
|
19 |
+
import math
|
20 |
import numpy as np
|
21 |
import pandas as pd
|
22 |
import pkg_resources as pkg
|
|
|
136 |
print(f'{e}{err_msg}')
|
137 |
|
138 |
|
139 |
+
def check_python(minimum='3.6.2'):
|
140 |
# Check current python version vs. required python version
|
141 |
+
check_version(platform.python_version(), minimum, name='Python ')
|
142 |
+
|
143 |
+
|
144 |
+
def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False):
|
145 |
+
# Check version vs. required version
|
146 |
+
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
|
147 |
+
result = (current == minimum) if pinned else (current >= minimum)
|
148 |
+
assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'
|
149 |
|
150 |
|
151 |
def check_requirements(requirements='requirements.txt', exclude=()):
|