# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
General utils
"""
import contextlib
import glob
import logging
import math
import os
import platform
import random
import re
import signal
import time
import urllib
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from subprocess import check_output
from zipfile import ZipFile
import cv2
import numpy as np
import pandas as pd
import pkg_resources as pkg
import torch
import torchvision
import yaml
from utils.downloads import gsutil_getsize
from utils.metrics import box_iou, fitness
# Settings
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
pd.options.display.max_columns = 10
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1] # YOLOv5 root directory
class Profile(contextlib.ContextDecorator):
# Usage: @Profile() decorator or 'with Profile():' context manager
def __enter__(self):
self.start = time.time()
def __exit__(self, type, value, traceback):
print(f'Profile results: {time.time() - self.start:.5f}s')
class Timeout(contextlib.ContextDecorator):
# Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
self.seconds = int(seconds)
self.timeout_message = timeout_msg
self.suppress = bool(suppress_timeout_errors)
def _timeout_handler(self, signum, frame):
raise TimeoutError(self.timeout_message)
def __enter__(self):
signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
def __exit__(self, exc_type, exc_val, exc_tb):
signal.alarm(0) # Cancel SIGALRM if it's scheduled
if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
return True
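# Illustrative sketch (not executed on import): Timeout relies on signal.SIGALRM, so it only works on
# Unix-like systems and in the main thread. A hypothetical use wrapping a slow call:
#   with Timeout(5, timeout_msg='download timed out'):
#       download_weights()  # hypothetical function; a TimeoutError after 5s is suppressed by default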
def try_except(func):
# try-except function. Usage: @try_except decorator
def handler(*args, **kwargs):
try:
func(*args, **kwargs)
except Exception as e:
print(e)
return handler
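# Illustrative sketch (assumed usage, not executed on import): try_except swallows any exception and
# prints it, so a decorated function simply returns None on failure:
#   @try_except
#   def flaky():  # hypothetical function
#       raise ValueError('boom')
#   flaky()  # prints 'boom' instead of raising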
def methods(instance):
# Get class/instance methods
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
def set_logging(rank=-1, verbose=True):
logging.basicConfig(
format="%(message)s",
level=logging.INFO if (verbose and rank in [-1, 0]) else logging.WARN)
def print_args(name, opt):
# Print argparser arguments
print(colorstr(f'{name}: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
def init_seeds(seed=0):
# Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
# cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
import torch.backends.cudnn as cudnn
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
def get_latest_run(search_dir='.'):
# Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
return max(last_list, key=os.path.getctime) if last_list else ''
def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
# Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
env = os.getenv(env_var)
if env:
path = Path(env) # use environment variable
else:
cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs
path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir
path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable
path.mkdir(exist_ok=True) # make if required
return path
def is_writeable(dir, test=False):
# Return True if directory has write permissions, test opening a file with write permissions if test=True
if test: # method 1
file = Path(dir) / 'tmp.txt'
try:
with open(file, 'w'): # open file with write permissions
pass
file.unlink() # remove file
return True
except IOError:
return False
else: # method 2
return os.access(dir, os.R_OK) # possible issues on Windows
def is_docker():
# Is environment a Docker container?
return Path('/workspace').exists() # or Path('/.dockerenv').exists()
def is_colab():
# Is environment a Google Colab instance?
try:
import google.colab
return True
except ImportError:
return False
def is_pip():
# Is file in a pip package?
return 'site-packages' in Path(__file__).resolve().parts
def is_ascii(s=''):
# Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
s = str(s) # convert list, tuple, None, etc. to str
return len(s.encode().decode('ascii', 'ignore')) == len(s)
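# Illustrative examples (assumed behavior based on the encode/decode round trip above):
#   is_ascii('yolov5s.pt')  # -> True, every character survives ASCII decoding
#   is_ascii('人工智能')      # -> False, non-ASCII characters are dropped so lengths differ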
def is_chinese(s='人工智能'):
# Is string composed of any Chinese characters?
return re.search('[\u4e00-\u9fff]', s)
def emojis(str=''):
# Return platform-dependent emoji-safe version of string
return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
def file_size(path):
# Return file/dir size (MB)
path = Path(path)
if path.is_file():
return path.stat().st_size / 1E6
elif path.is_dir():
return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / 1E6
else:
return 0.0
def check_online():
# Check internet connectivity
import socket
try:
socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
return True
except OSError:
return False
@try_except
def check_git_status():
# Recommend 'git pull' if code is out of date
msg = ', for updates see https://github.com/ultralytics/yolov5'
print(colorstr('github: '), end='')
assert Path('.git').exists(), 'skipping check (not a git repository)' + msg
assert not is_docker(), 'skipping check (Docker image)' + msg
assert check_online(), 'skipping check (offline)' + msg
cmd = 'git fetch && git config --get remote.origin.url'
url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
if n > 0:
s = f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
else:
s = f'up to date with {url} ✅'
print(emojis(s)) # emoji-safe
def check_python(minimum='3.6.2'):
# Check current python version vs. required python version
check_version(platform.python_version(), minimum, name='Python ')
def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False):
# Check version vs. required version
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
result = (current == minimum) if pinned else (current >= minimum)
assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'
@try_except
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True):
# Check installed dependencies meet requirements (pass *.txt file or list of packages)
prefix = colorstr('red', 'bold', 'requirements:')
check_python() # check python version
if isinstance(requirements, (str, Path)): # requirements.txt file
file = Path(requirements)
assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
else: # list or tuple of packages
requirements = [x for x in requirements if x not in exclude]
n = 0 # number of packages updates
for r in requirements:
try:
pkg.require(r)
except Exception as e: # DistributionNotFound or VersionConflict if requirements not met
s = f"{prefix} {r} not found and is required by YOLOv5"
if install:
print(f"{s}, attempting auto-update...")
try:
assert check_online(), f"'pip install {r}' skipped (offline)"
print(check_output(f"pip install '{r}'", shell=True).decode())
n += 1
except Exception as e:
print(f'{prefix} {e}')
else:
print(f'{s}. Please install and rerun your command.')
if n: # if packages updated
source = file.resolve() if 'file' in locals() else requirements
s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
print(emojis(s))
def check_img_size(imgsz, s=32, floor=0):
# Verify image size is a multiple of stride s in each dimension
if isinstance(imgsz, int): # integer i.e. img_size=640
new_size = max(make_divisible(imgsz, int(s)), floor)
else: # list i.e. img_size=[640, 480]
new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
if new_size != imgsz:
print(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
return new_size
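# Illustrative examples (assumed stride s=32, the default maximum stride of YOLOv5 P5 models):
#   check_img_size(640)         # -> 640, already a multiple of 32
#   check_img_size(641)         # -> 672, rounded up via make_divisible() with a warning
#   check_img_size([640, 480])  # -> [640, 480], each dimension checked independently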
def check_imshow():
# Check if environment supports image displays
try:
assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
cv2.imshow('test', np.zeros((1, 1, 3)))
cv2.waitKey(1)
cv2.destroyAllWindows()
cv2.waitKey(1)
return True
except Exception as e:
print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
return False
def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
# Check file(s) for acceptable suffixes
if file and suffix:
if isinstance(suffix, str):
suffix = [suffix]
for f in file if isinstance(file, (list, tuple)) else [file]:
assert Path(f).suffix.lower() in suffix, f"{msg}{f} acceptable suffix is {suffix}"
def check_yaml(file, suffix=('.yaml', '.yml')):
# Search/download YAML file (if necessary) and return path, checking suffix
return check_file(file, suffix)
def check_file(file, suffix=''):
# Search/download file (if necessary) and return path
check_suffix(file, suffix) # optional
file = str(file) # convert to str()
if Path(file).is_file() or file == '': # exists
return file
elif file.startswith(('http:/', 'https:/')): # download
url = str(Path(file)).replace(':/', '://') # Pathlib turns :// -> :/
file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
print(f'Downloading {url} to {file}...')
torch.hub.download_url_to_file(url, file)
assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
return file
else: # search
files = []
for d in 'data', 'models', 'utils': # search directories
files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
assert len(files), f'File not found: {file}' # assert file was found
assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
return files[0] # return file
def check_dataset(data, autodownload=True):
# Download and/or unzip dataset if not found locally
# Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip
# Download (optional)
extract_dir = ''
if isinstance(data, (str, Path)) and str(data).endswith('.zip'): # i.e. gs://bucket/dir/coco128.zip
download(data, dir='../datasets', unzip=True, delete=False, curl=False, threads=1)
data = next((Path('../datasets') / Path(data).stem).rglob('*.yaml'))
extract_dir, autodownload = data.parent, False
# Read yaml (optional)
if isinstance(data, (str, Path)):
with open(data, errors='ignore') as f:
data = yaml.safe_load(f) # dictionary
# Parse yaml
path = extract_dir or Path(data.get('path') or '') # optional 'path' default to '.'
for k in 'train', 'val', 'test':
if data.get(k): # prepend path
data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]
assert 'nc' in data, "Dataset 'nc' key missing."
if 'names' not in data:
data['names'] = [f'class{i}' for i in range(data['nc'])] # assign class names if missing
train, val, test, s = [data.get(x) for x in ('train', 'val', 'test', 'download')]
if val:
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
if not all(x.exists() for x in val):
print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
if s and autodownload: # download script
root = path.parent if 'path' in data else '..' # unzip directory i.e. '../'
if s.startswith('http') and s.endswith('.zip'): # URL
f = Path(s).name # filename
print(f'Downloading {s} to {f}...')
torch.hub.download_url_to_file(s, f)
Path(root).mkdir(parents=True, exist_ok=True) # create root
ZipFile(f).extractall(path=root) # unzip
Path(f).unlink() # remove zip
r = None # success
elif s.startswith('bash '): # bash script
print(f'Running {s} ...')
r = os.system(s)
else: # python script
r = exec(s, {'yaml': data}) # return None
print(f"Dataset autodownload {f'success, saved to {root}' if r in (0, None) else 'failure'}\n")
else:
raise Exception('Dataset not found.')
return data # dictionary
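# Illustrative sketch (hypothetical path): passing a dataset YAML returns the parsed dictionary with
# 'train'/'val'/'test' paths resolved against the optional 'path' key and 'names' filled in if missing;
# the dataset may be auto-downloaded when a 'download' key is present and the val paths do not exist:
#   data = check_dataset('data/coco128.yaml')
#   print(data['nc'], data['names'][:3])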
def url2file(url):
# Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
file = Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
return file
def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
# Multi-threaded file download and unzip function, used in data.yaml for autodownload
def download_one(url, dir):
# Download 1 file
f = dir / Path(url).name # filename
if Path(url).is_file(): # exists in current path
Path(url).rename(f) # move to dir
elif not f.exists():
print(f'Downloading {url} to {f}...')
if curl:
os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -") # curl download, retry and resume on fail
else:
torch.hub.download_url_to_file(url, f, progress=True) # torch download
if unzip and f.suffix in ('.zip', '.gz'):
print(f'Unzipping {f}...')
if f.suffix == '.zip':
ZipFile(f).extractall(path=dir) # unzip
elif f.suffix == '.gz':
os.system(f'tar xfz {f} --directory {f.parent}') # unzip
if delete:
f.unlink() # remove zip
dir = Path(dir)
dir.mkdir(parents=True, exist_ok=True) # make directory
if threads > 1:
pool = ThreadPool(threads)
pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
pool.close()
pool.join()
else:
for u in [url] if isinstance(url, (str, Path)) else url:
download_one(u, dir)
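# Illustrative sketch (hypothetical URL and directory): download() accepts a single URL or an iterable
# of URLs; with threads > 1 the files are fetched concurrently via a ThreadPool:
#   download('https://example.com/assets.zip', dir='../datasets', unzip=True, delete=False, threads=1)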
def make_divisible(x, divisor):
# Returns x evenly divisible by divisor
return math.ceil(x / divisor) * divisor
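# Illustrative examples (rounding is always upward via math.ceil):
#   make_divisible(100, 32)  # -> 128
#   make_divisible(640, 32)  # -> 640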
def clean_str(s):
# Cleans a string by replacing special characters with underscore _
return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
def one_cycle(y1=0.0, y2=1.0, steps=100):
# lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
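# Illustrative example (assumed scheduler use): the returned lambda maps x in [0, steps] onto a
# half-cosine ramp from y1 to y2, e.g. a learning-rate multiplier decaying from 1.0 to 0.1:
#   lf = one_cycle(1.0, 0.1, 100)
#   lf(0), lf(50), lf(100)  # -> 1.0, 0.55, 0.1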
def colorstr(*input):
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
*args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
colors = {'black': '\033[30m', # basic colors
'red': '\033[31m',
'green': '\033[32m',
'yellow': '\033[33m',
'blue': '\033[34m',
'magenta': '\033[35m',
'cyan': '\033[36m',
'white': '\033[37m',
'bright_black': '\033[90m', # bright colors
'bright_red': '\033[91m',
'bright_green': '\033[92m',
'bright_yellow': '\033[93m',
'bright_blue': '\033[94m',
'bright_magenta': '\033[95m',
'bright_cyan': '\033[96m',
'bright_white': '\033[97m',
'end': '\033[0m', # misc
'bold': '\033[1m',
'underline': '\033[4m'}
return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
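# Illustrative examples (ANSI escape codes, rendered by most terminals):
#   colorstr('hello world')                   # -> blue bold 'hello world' (default style)
#   colorstr('red', 'bold', 'requirements:')  # -> '\033[31m\033[1mrequirements:\033[0m'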
def labels_to_class_weights(labels, nc=80):
# Get class weights (inverse frequency) from training labels
if labels[0] is None: # no labels loaded
return torch.Tensor()
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
classes = labels[:, 0].astype(int) # labels = [class xywh]; plain int avoids the deprecated np.int alias
weights = np.bincount(classes, minlength=nc) # occurrences per class
# Prepend gridpoint count (for uCE training)
# gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
# weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
weights[weights == 0] = 1 # replace empty bins with 1
weights = 1 / weights # number of targets per class
weights /= weights.sum() # normalize
return torch.from_numpy(weights)
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
# Produces image weights based on class_weights and image contents
class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
# index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
return image_weights
def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
# https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
# a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
# b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
# x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
# x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
return x
def xyxy2xywh(x):
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
y[:, 2] = x[:, 2] - x[:, 0] # width
y[:, 3] = x[:, 3] - x[:, 1] # height
return y
def xywh2xyxy(x):
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
return y
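# Illustrative round trip (assumed float inputs; xyxy2xywh and xywh2xyxy are inverses):
#   b = np.array([[10., 10., 4., 6.]])  # center (10, 10), width 4, height 6
#   xywh2xyxy(b)                         # -> [[ 8.,  7., 12., 13.]]
#   xyxy2xywh(xywh2xyxy(b))              # -> [[10., 10.,  4.,  6.]]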
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
# Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
return y
def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
if clip:
clip_coords(x, (h - eps, w - eps)) # warning: inplace clip
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
return y
def xyn2xy(x, w=640, h=640, padw=0, padh=0):
# Convert normalized segments into pixel segments, shape (n,2)
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = w * x[:, 0] + padw # top left x
y[:, 1] = h * x[:, 1] + padh # top left y
return y
def segment2box(segment, width=640, height=640):
# Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
x, y = segment.T # segment xy
inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
x, y, = x[inside], y[inside]
return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
def segments2boxes(segments):
# Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
boxes = []
for s in segments:
x, y = s.T # segment xy
boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
return xyxy2xywh(np.array(boxes)) # cls, xywh
def resample_segments(segments, n=1000):
# Up-sample an (n,2) segment
for i, s in enumerate(segments):
x = np.linspace(0, len(s) - 1, n)
xp = np.arange(len(s))
segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
return segments
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
# Rescale coords (xyxy) from img1_shape to img0_shape
if ratio_pad is None: # calculate from img0_shape
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
else:
gain = ratio_pad[0][0]
pad = ratio_pad[1]
coords[:, [0, 2]] -= pad[0] # x padding
coords[:, [1, 3]] -= pad[1] # y padding
coords[:, :4] /= gain
clip_coords(coords, img0_shape)
return coords
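# Illustrative sketch (assumed detect-style usage, det is a hypothetical (n,6) detections tensor):
# map boxes predicted on the letterboxed inference image (img) back to the original image (im0):
#   det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()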
def clip_coords(boxes, shape):
# Clip xyxy bounding boxes to image shape (height, width)
if isinstance(boxes, torch.Tensor): # faster individually
boxes[:, 0].clamp_(0, shape[1]) # x1
boxes[:, 1].clamp_(0, shape[0]) # y1
boxes[:, 2].clamp_(0, shape[1]) # x2
boxes[:, 3].clamp_(0, shape[0]) # y2
else: # np.array (faster grouped)
boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
labels=(), max_det=300):
"""Runs Non-Maximum Suppression (NMS) on inference results
Returns:
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
"""
nc = prediction.shape[2] - 5 # number of classes
xc = prediction[..., 4] > conf_thres # candidates
# Checks
assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
# Settings
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
time_limit = 10.0 # seconds to quit after
redundant = True # require redundant detections
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
merge = False # use merge-NMS
t = time.time()
output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
x = x[xc[xi]] # confidence
# Cat apriori labels if autolabelling
if labels and len(labels[xi]):
l = labels[xi]
v = torch.zeros((len(l), nc + 5), device=x.device)
v[:, :4] = l[:, 1:5] # box
v[:, 4] = 1.0 # conf
v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls
x = torch.cat((x, v), 0)
# If none remain process next image
if not x.shape[0]:
continue
# Compute conf
x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
box = xywh2xyxy(x[:, :4])
# Detections matrix nx6 (xyxy, conf, cls)
if multi_label:
i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
else: # best class only
conf, j = x[:, 5:].max(1, keepdim=True)
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
# Filter by class
if classes is not None:
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
# Apply finite constraint
# if not torch.isfinite(x).all():
# x = x[torch.isfinite(x).all(1)]
# Check shape
n = x.shape[0] # number of boxes
if not n: # no boxes
continue
elif n > max_nms: # excess boxes
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
# Batched NMS
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
if i.shape[0] > max_det: # limit detections
i = i[:max_det]
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
weights = iou * scores[None] # box weights
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
if redundant:
i = i[iou.sum(1) > 1] # require redundancy
output[xi] = x[i]
if (time.time() - t) > time_limit:
print(f'WARNING: NMS time limit {time_limit}s exceeded')
break # time limit exceeded
return output
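# Illustrative sketch (assumed inference usage, model_out is a hypothetical raw output tensor of shape
# (batch, boxes, 5 + nc)): the result is a list with one (n, 6) [xyxy, conf, cls] tensor per image:
#   pred = non_max_suppression(model_out, conf_thres=0.25, iou_thres=0.45, max_det=300)
#   for det in pred:  # per image
#       ...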
def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
# Strip optimizer from 'f' to finalize training, optionally save as 's'
x = torch.load(f, map_location=torch.device('cpu'))
if x.get('ema'):
x['model'] = x['ema'] # replace model with ema
for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates': # keys
x[k] = None
x['epoch'] = -1
x['model'].half() # to FP16
for p in x['model'].parameters():
p.requires_grad = False
torch.save(x, s or f)
mb = os.path.getsize(s or f) / 1E6 # filesize
print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
def print_mutation(results, hyp, save_dir, bucket):
evolve_csv, results_csv, evolve_yaml = save_dir / 'evolve.csv', save_dir / 'results.csv', save_dir / 'hyp_evolve.yaml'
keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
'val/box_loss', 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys()) # [results + hyps]
keys = tuple(x.strip() for x in keys)
vals = results + tuple(hyp.values())
n = len(keys)
# Download (optional)
if bucket:
url = f'gs://{bucket}/evolve.csv'
if gsutil_getsize(url) > (os.path.getsize(evolve_csv) if os.path.exists(evolve_csv) else 0):
os.system(f'gsutil cp {url} {save_dir}') # download evolve.csv if larger than local
# Log to evolve.csv
s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n') # add header
with open(evolve_csv, 'a') as f:
f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')
# Print to screen
print(colorstr('evolve: ') + ', '.join(f'{x.strip():>20s}' for x in keys))
print(colorstr('evolve: ') + ', '.join(f'{x:20.5g}' for x in vals), end='\n\n\n')
# Save yaml
with open(evolve_yaml, 'w') as f:
data = pd.read_csv(evolve_csv)
data = data.rename(columns=lambda x: x.strip()) # strip keys
i = np.argmax(fitness(data.values[:, :7])) # index of best generation by fitness
f.write('# YOLOv5 Hyperparameter Evolution Results\n' +
f'# Best generation: {i}\n' +
f'# Last generation: {len(data)}\n' +
'# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) + '\n' +
'# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
yaml.safe_dump(hyp, f, sort_keys=False)
if bucket:
os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}') # upload
def apply_classifier(x, model, img, im0):
# Apply a second stage classifier to yolo outputs
im0 = [im0] if isinstance(im0, np.ndarray) else im0
for i, d in enumerate(x): # per image
if d is not None and len(d):
d = d.clone()
# Reshape and pad cutouts
b = xyxy2xywh(d[:, :4]) # boxes
b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
d[:, :4] = xywh2xyxy(b).long()
# Rescale boxes from img_size to im0 size
scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
# Classes
pred_cls1 = d[:, 5].long()
ims = []
for j, a in enumerate(d): # per item
cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
im = cv2.resize(cutout, (224, 224)) # BGR
# cv2.imwrite('example%i.jpg' % j, cutout)
im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
im /= 255.0 # 0 - 255 to 0.0 - 1.0
ims.append(im)
pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
return x
def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False, save=True):
# Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
xyxy = torch.tensor(xyxy).view(-1, 4)
b = xyxy2xywh(xyxy) # boxes
if square:
b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
xyxy = xywh2xyxy(b).long()
clip_coords(xyxy, im.shape)
crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
if save:
cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop)
return crop
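# Illustrative sketch (assumed cv2 BGR input image im0 and a single xyxy box, hypothetical path):
# save a slightly padded crop of a detection, or get the crop back without writing to disk:
#   crop = save_one_box(xyxy, im0, file='runs/crops/person.jpg', BGR=True)
#   crop = save_one_box(xyxy, im0, save=False)  # return channel-reversed (RGB) crop only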
def increment_path(path, exist_ok=False, sep='', mkdir=False):
# Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
path = Path(path) # os-agnostic
if path.exists() and not exist_ok:
suffix = path.suffix
path = path.with_suffix('')
dirs = glob.glob(f"{path}{sep}*") # similar paths
matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
i = [int(m.groups()[0]) for m in matches if m] # indices
n = max(i) + 1 if i else 2 # increment number
path = Path(f"{path}{sep}{n}{suffix}") # update path
dir = path if path.suffix == '' else path.parent # directory
if not dir.exists() and mkdir:
dir.mkdir(parents=True, exist_ok=True) # make directory
return path
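# Illustrative examples (assumed that runs/exp already exists on disk):
#   increment_path('runs/exp')                 # -> Path('runs/exp2')
#   increment_path('runs/exp', sep='_')        # -> Path('runs/exp_2'), or the next index
#   increment_path('runs/exp', exist_ok=True)  # -> Path('runs/exp'), reuse the existing directory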