glenn-jocher commited on
Commit
33202b7
1 Parent(s): 6a3ee7c

YOLOv5 + Albumentations integration (#3882)

Browse files

* Albumentations integration

* ToGray p=0.01

* print confirmation

* create instance in dataloader init method

* improved version handling

* transform not defined fix

* assert string update

* create check_version()

* add spaces

* update class comment

Files changed (4) hide show
  1. requirements.txt +1 -0
  2. utils/augmentations.py +29 -1
  3. utils/datasets.py +21 -19
  4. utils/general.py +10 -7
requirements.txt CHANGED
@@ -27,4 +27,5 @@ pandas
27
  # extras --------------------------------------
28
  # Cython # for pycocotools https://github.com/cocodataset/cocoapi/issues/172
29
  # pycocotools>=2.0 # COCO mAP
 
30
  thop # FLOPs computation
 
27
  # extras --------------------------------------
28
  # Cython # for pycocotools https://github.com/cocodataset/cocoapi/issues/172
29
  # pycocotools>=2.0 # COCO mAP
30
+ # albumentations>=1.0.0
31
  thop # FLOPs computation
utils/augmentations.py CHANGED
@@ -1,15 +1,43 @@
1
  # YOLOv5 image augmentation functions
2
 
 
3
  import random
4
 
5
  import cv2
6
  import math
7
  import numpy as np
8
 
9
- from utils.general import segment2box, resample_segments
10
  from utils.metrics import bbox_ioa
11
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
14
  # HSV color-space augmentation
15
  if hgain or sgain or vgain:
 
1
  # YOLOv5 image augmentation functions
2
 
3
+ import logging
4
  import random
5
 
6
  import cv2
7
  import math
8
  import numpy as np
9
 
10
+ from utils.general import colorstr, segment2box, resample_segments, check_version
11
  from utils.metrics import bbox_ioa
12
 
13
 
14
+ class Albumentations:
15
+ # YOLOv5 Albumentations class (optional, only used if package is installed)
16
+ def __init__(self):
17
+ self.transform = None
18
+ try:
19
+ import albumentations as A
20
+ check_version(A.__version__, '1.0.0') # version requirement
21
+
22
+ self.transform = A.Compose([
23
+ A.Blur(p=0.1),
24
+ A.MedianBlur(p=0.1),
25
+ A.ToGray(p=0.01)],
26
+ bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
27
+
28
+ logging.info(colorstr('albumentations: ') + ', '.join(f'{x}' for x in self.transform.transforms))
29
+ except ImportError: # package not installed, skip
30
+ pass
31
+ except Exception as e:
32
+ logging.info(colorstr('albumentations: ') + f'{e}')
33
+
34
+ def __call__(self, im, labels, p=1.0):
35
+ if self.transform and random.random() < p:
36
+ new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0]) # transformed
37
+ im, labels = new['image'], np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])])
38
+ return im, labels
39
+
40
+
41
  def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
42
  # HSV color-space augmentation
43
  if hgain or sgain or vgain:
utils/datasets.py CHANGED
@@ -22,7 +22,7 @@ from PIL import Image, ExifTags
22
  from torch.utils.data import Dataset
23
  from tqdm import tqdm
24
 
25
- from utils.augmentations import augment_hsv, copy_paste, letterbox, mixup, random_perspective
26
  from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \
27
  xyn2xy, segments2boxes, clean_str
28
  from utils.torch_utils import torch_distributed_zero_first
@@ -372,6 +372,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
372
  self.mosaic_border = [-img_size // 2, -img_size // 2]
373
  self.stride = stride
374
  self.path = path
 
375
 
376
  try:
377
  f = [] # image files
@@ -539,9 +540,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
539
  if labels.size: # normalized xywh to pixel xyxy format
540
  labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
541
 
542
- if self.augment:
543
- # Augment imagespace
544
- if not mosaic:
545
  img, labels = random_perspective(img, labels,
546
  degrees=hyp['degrees'],
547
  translate=hyp['translate'],
@@ -549,32 +548,35 @@ class LoadImagesAndLabels(Dataset): # for training/testing
549
  shear=hyp['shear'],
550
  perspective=hyp['perspective'])
551
 
552
- # Augment colorspace
553
- augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
554
-
555
- # Apply cutouts
556
- # if random.random() < 0.9:
557
- # labels = cutout(img, labels)
558
-
559
- nL = len(labels) # number of labels
560
- if nL:
561
  labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0]) # xyxy to xywh normalized
562
 
563
  if self.augment:
564
- # flip up-down
 
 
 
 
 
 
565
  if random.random() < hyp['flipud']:
566
  img = np.flipud(img)
567
- if nL:
568
  labels[:, 2] = 1 - labels[:, 2]
569
 
570
- # flip left-right
571
  if random.random() < hyp['fliplr']:
572
  img = np.fliplr(img)
573
- if nL:
574
  labels[:, 1] = 1 - labels[:, 1]
575
 
576
- labels_out = torch.zeros((nL, 6))
577
- if nL:
 
 
 
 
578
  labels_out[:, 1:] = torch.from_numpy(labels)
579
 
580
  # Convert
 
22
  from torch.utils.data import Dataset
23
  from tqdm import tqdm
24
 
25
+ from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
26
  from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \
27
  xyn2xy, segments2boxes, clean_str
28
  from utils.torch_utils import torch_distributed_zero_first
 
372
  self.mosaic_border = [-img_size // 2, -img_size // 2]
373
  self.stride = stride
374
  self.path = path
375
+ self.albumentations = Albumentations() if augment else None
376
 
377
  try:
378
  f = [] # image files
 
540
  if labels.size: # normalized xywh to pixel xyxy format
541
  labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
542
 
543
+ if self.augment:
 
 
544
  img, labels = random_perspective(img, labels,
545
  degrees=hyp['degrees'],
546
  translate=hyp['translate'],
 
548
  shear=hyp['shear'],
549
  perspective=hyp['perspective'])
550
 
551
+ nl = len(labels) # number of labels
552
+ if nl:
 
 
 
 
 
 
 
553
  labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0]) # xyxy to xywh normalized
554
 
555
  if self.augment:
556
+ # Albumentations
557
+ img, labels = self.albumentations(img, labels)
558
+
559
+ # HSV color-space
560
+ augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
561
+
562
+ # Flip up-down
563
  if random.random() < hyp['flipud']:
564
  img = np.flipud(img)
565
+ if nl:
566
  labels[:, 2] = 1 - labels[:, 2]
567
 
568
+ # Flip left-right
569
  if random.random() < hyp['fliplr']:
570
  img = np.fliplr(img)
571
+ if nl:
572
  labels[:, 1] = 1 - labels[:, 1]
573
 
574
+ # Cutouts
575
+ # if random.random() < 0.9:
576
+ # labels = cutout(img, labels)
577
+
578
+ labels_out = torch.zeros((nl, 6))
579
+ if nl:
580
  labels_out[:, 1:] = torch.from_numpy(labels)
581
 
582
  # Convert
utils/general.py CHANGED
@@ -3,7 +3,6 @@
3
  import contextlib
4
  import glob
5
  import logging
6
- import math
7
  import os
8
  import platform
9
  import random
@@ -17,6 +16,7 @@ from pathlib import Path
17
  from subprocess import check_output
18
 
19
  import cv2
 
20
  import numpy as np
21
  import pandas as pd
22
  import pkg_resources as pkg
@@ -136,13 +136,16 @@ def check_git_status(err_msg=', for updates see https://github.com/ultralytics/y
136
  print(f'{e}{err_msg}')
137
 
138
 
139
- def check_python(minimum='3.6.2', required=True):
140
  # Check current python version vs. required python version
141
- current = platform.python_version()
142
- result = pkg.parse_version(current) >= pkg.parse_version(minimum)
143
- if required:
144
- assert result, f'Python {minimum} required by YOLOv5, but Python {current} is currently installed'
145
- return result
 
 
 
146
 
147
 
148
  def check_requirements(requirements='requirements.txt', exclude=()):
 
3
  import contextlib
4
  import glob
5
  import logging
 
6
  import os
7
  import platform
8
  import random
 
16
  from subprocess import check_output
17
 
18
  import cv2
19
+ import math
20
  import numpy as np
21
  import pandas as pd
22
  import pkg_resources as pkg
 
136
  print(f'{e}{err_msg}')
137
 
138
 
139
+ def check_python(minimum='3.6.2'):
140
  # Check current python version vs. required python version
141
+ check_version(platform.python_version(), minimum, name='Python ')
142
+
143
+
144
+ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False):
145
+ # Check version vs. required version
146
+ current, minimum = (pkg.parse_version(x) for x in (current, minimum))
147
+ result = (current == minimum) if pinned else (current >= minimum)
148
+ assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'
149
 
150
 
151
  def check_requirements(requirements='requirements.txt', exclude=()):