# yolov8m / utils.py
# zhengrongzhang's picture
# init model
# e6c79f4
import threading
import os
import contextlib
import torch
import torch.nn as nn
from PIL import Image, ImageDraw, ImageFont, ExifTags
from PIL import __version__ as pil_version
from multiprocessing.pool import ThreadPool
import numpy as np
from itertools import repeat
import glob
import cv2
import tempfile
import hashlib
from pathlib import Path
import time
import torchvision
import math
import re
from typing import List, Union, Dict
import pkg_resources as pkg
from types import SimpleNamespace
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import random
import yaml
import logging.config
import sys
import pathlib
# Directory containing this file; appended to sys.path so that modules located
# next to it can be imported by plain name.
CURRENT_DIR = pathlib.Path(__file__).parent
sys.path.append(str(CURRENT_DIR))

LOGGING_NAME = 'ultralytics'
LOGGER = logging.getLogger(LOGGING_NAME)


def _single_message_logger(log_method):
    """Return a wrapper that forwards exactly one message argument to *log_method*."""
    def log(x):
        return log_method(x)
    return log


# Replace LOGGER.info / LOGGER.warning with single-argument wrappers.
# Each wrapper must capture its own bound method: a bare `lambda x: fn(x)` looks
# `fn` up when it is *called*, i.e. after the loop has finished, which made
# LOGGER.info emit records at WARNING level.
for fn in LOGGER.info, LOGGER.warning:
    setattr(LOGGER, fn.__name__, _single_message_logger(fn))
# File suffixes (lower-case, without the dot) recognised as images / videos.
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"  # include image suffixes
VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv"  # include video suffixes
TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}'  # tqdm bar format
# os.cpu_count() returns None when the CPU count cannot be determined; fall back to 1
# so that the min() below never compares an int with None.
NUM_THREADS = min(8, os.cpu_count() or 1)
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"  # global pin_memory for dataloaders
_formats = ["xyxy", "xywh", "ltwh"]

# Configuration keys grouped by the kind of value they carry.
CFG_FLOAT_KEYS = {'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear'}
CFG_FRACTION_KEYS = {
    'dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr', 'fl_gamma',
    'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud', 'fliplr', 'mosaic',
    'mixup', 'copy_paste', 'conf'}
CFG_INT_KEYS = {
    'epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
    'line_thickness', 'workspace', 'nbs'}
CFG_BOOL_KEYS = {
    'save', 'exist_ok', 'pretrained', 'verbose', 'deterministic', 'single_cls', 'image_weights', 'rect', 'cos_lr',
    'overlap_mask', 'val', 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf',
    'save_crop', 'hide_labels', 'hide_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks', 'boxes', 'keras',
    'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'v5loader'}
# Find the numeric EXIF tag id whose name is 'Orientation'; the loop variable
# `orientation` is left bound to that id at module level.
for orientation in ExifTags.TAGS:
    if ExifTags.TAGS[orientation] == 'Orientation':
        break
def segments2boxes(segments):
    """
    Convert segment labels to box labels, i.e. (xy1, xy2, ...) to (xywh).

    Args:
        segments (list): list of segments; each segment must expose ``.T`` yielding its x and y
            coordinate sequences (presumably an (n, 2) np.ndarray - confirm against callers).

    Returns:
        (np.ndarray): the xywh boxes enclosing each segment, shape (len(segments), 4).
            An empty input gives an empty (0, 4) array.
    """
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # xyxy
    # reshape keeps the array 2-D even when `segments` is empty; otherwise the
    # resulting shape-(0,) array cannot be column-indexed by xyxy2xywh.
    return xyxy2xywh(np.array(boxes).reshape(-1, 4))  # xywh
def check_version(
    current: str = "0.0.0",
    minimum: str = "0.0.0",
    name: str = "version ",
    pinned: bool = False,
    hard: bool = False,
    verbose: bool = False,
) -> bool:
    """
    Check current version against the required minimum version.

    Args:
        current (str): Current version.
        minimum (str): Required minimum version.
        name (str): Name to be used in warning message.
        pinned (bool): If True, versions must match exactly. If False, minimum version must be satisfied.
        hard (bool): If True, raise an AssertionError if the minimum version is not met.
        verbose (bool): If True, print warning message if minimum version is not met.

    Returns:
        bool: True if minimum version is met, False otherwise.

    Raises:
        AssertionError: If ``hard`` is True and the requirement is not met.
    """
    current, minimum = (pkg.parse_version(x) for x in (current, minimum))
    result = (current == minimum) if pinned else (current >= minimum)  # bool
    if result:
        return True

    warning_message = f"WARNING ⚠️ {name}{minimum} is required by YOLOv8, but {name}{current} is currently installed"
    if hard:
        # Raised explicitly (not via `assert`) so the documented check survives `python -O`.
        raise AssertionError(warning_message)
    if verbose:
        LOGGER.warning(warning_message)
    return False
# True when the installed torch is at least 1.9.0.
TORCH_1_9 = check_version(torch.__version__, '1.9.0')


def smart_inference_mode():
    """Return a decorator that runs the wrapped function under torch.inference_mode()
    when TORCH_1_9 is true, and under torch.no_grad() otherwise."""

    def decorate(fn):
        # Choose the context-manager class at decoration time, instantiate it,
        # and use the instance as the decorator.
        if TORCH_1_9:
            context_cls = torch.inference_mode
        else:
            context_cls = torch.no_grad
        return context_cls()(fn)

    return decorate
def box_iou(box1, box2, eps=1e-7):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Compute the pairwise intersection-over-union (Jaccard index) of two box sets.

    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.

    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
        eps (float): small constant added to the denominator to avoid division by zero.

    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in box1 and box2
    """
    # Split each box into its (x1, y1) and (x2, y2) halves, arranged so they broadcast to (N, M, 2).
    top_left_1, bottom_right_1 = box1.unsqueeze(1).chunk(2, dim=2)
    top_left_2, bottom_right_2 = box2.unsqueeze(0).chunk(2, dim=2)

    # Overlap extent per axis, clamped at zero for disjoint boxes, multiplied into an area.
    overlap = torch.min(bottom_right_1, bottom_right_2) - torch.max(top_left_1, top_left_2)
    inter = overlap.clamp(0).prod(2)

    area_1 = (bottom_right_1 - top_left_1).prod(2)
    area_2 = (bottom_right_2 - top_left_2).prod(2)

    # IoU = inter / (area1 + area2 - inter)
    union = area_1 + area_2 - inter + eps
    return inter / union
class LoadImages:
    """
    YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`.

    Iterating yields ``(path, im, im0, cap, s)`` tuples where ``im0`` is the frame as read by
    cv2, ``im`` is the pre-processed frame, ``cap`` is the current cv2.VideoCapture (None when
    no video was found) and ``s`` is a progress string.
    """

    def __init__(
        self, path, imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1
    ):
        """
        Collect the image and video files designated by ``path``.

        Args:
            path: a file, directory or glob pattern, a list/tuple of those, or a ``*.txt`` file
                listing them.
            imgsz: target size handed to LetterBox.
            stride: stride handed to LetterBox.
            auto: ``auto`` flag handed to LetterBox.
            transforms: optional callable applied to each frame instead of LetterBox.
            vid_stride: number of frames grabbed per returned video frame.

        Raises:
            FileNotFoundError: if an entry does not exist or no supported file is found.
        """
        # *.txt file with img/vid/dir on each line
        if isinstance(path, str) and Path(path).suffix == ".txt":
            path = Path(path).read_text().split()

        files = []
        p = path  # fallback for the error message below when there is nothing to iterate
        for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
            p = str(Path(p).resolve())
            if "*" in p:
                files.extend(sorted(glob.glob(p, recursive=True)))  # glob
            elif os.path.isdir(p):
                files.extend(sorted(glob.glob(os.path.join(p, "*.*"))))  # dir
            elif os.path.isfile(p):
                files.append(p)  # files
            else:
                raise FileNotFoundError(f"{p} does not exist")

        # Keep only files with a supported suffix
        images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS]
        videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS]
        ni, nv = len(images), len(videos)

        self.imgsz = imgsz
        self.stride = stride
        self.files = images + videos  # images first, then videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = "image"
        self.auto = auto
        self.transforms = transforms  # optional
        self.vid_stride = vid_stride  # video frame-rate stride
        self.bs = 1
        self.count = 0  # so that __next__ is usable even before __iter__ is called
        self.orientation = None  # rotation degrees; always defined for _cv2_rotate
        self.cap = None

        if self.nf == 0:
            raise FileNotFoundError(
                f"No images or videos found in {p}. "
                f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
            )
        if any(videos):
            self._new_video(videos[0])  # new video

    def __iter__(self):
        """Reset the file counter and return self."""
        self.count = 0
        return self

    def __next__(self):
        """Return the next ``(path, im, im0, cap, s)`` tuple or raise StopIteration."""
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = "video"
            for _ in range(self.vid_stride):
                self.cap.grab()
            success, im0 = self.cap.retrieve()
            while not success:
                # Current video is exhausted (or unreadable): move on to the next file
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                path = self.files[self.count]
                self._new_video(path)
                success, im0 = self.cap.read()

            self.frame += 1
            s = f"video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: "
        else:
            # Read image
            self.count += 1
            im0 = cv2.imread(path)  # BGR
            if im0 is None:
                raise FileNotFoundError(f"Image Not Found {path}")
            s = f"image {self.count}/{self.nf} {path}: "

        if self.transforms:
            im = self.transforms(im0)  # transforms
        else:
            im = LetterBox(self.imgsz, self.auto, stride=self.stride)(image=im0)
            im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
            im = np.ascontiguousarray(im)  # contiguous

        return path, im, im0, self.cap, s

    def _new_video(self, path):
        """Create a new video capture object for ``path`` and reset the frame counters."""
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
        if hasattr(cv2, "CAP_PROP_ORIENTATION_META"):  # cv2<4.6.0 compatibility
            self.orientation = int(
                self.cap.get(cv2.CAP_PROP_ORIENTATION_META)
            )  # rotation degrees
        # Disable auto-orientation due to known issues in https://github.com/ultralytics/yolov5/issues/8493
        # self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0)

    def _cv2_rotate(self, im):
        """Rotate a cv2 video frame manually according to ``self.orientation``.

        NOTE(review): the orientation-to-rotation mapping is kept exactly as found
        (0 -> 90 clockwise, 180 -> 90 counter-clockwise, 90 -> 180); confirm it is intended.
        """
        if self.orientation == 0:
            return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
        elif self.orientation == 180:
            return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
        elif self.orientation == 90:
            return cv2.rotate(im, cv2.ROTATE_180)
        return im

    def __len__(self):
        """Return the number of files."""
        return self.nf  # number of files
class LetterBox:
    """Resize image and padding for detection, instance segmentation, pose"""

    def __init__(
        self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32
    ):
        """Store the target shape and the resize/padding options used by __call__."""
        self.new_shape = new_shape
        self.auto = auto
        self.scaleFill = scaleFill
        self.scaleup = scaleup
        self.stride = stride

    def __call__(self, labels=None, image=None):
        """
        Letterbox ``image`` (or ``labels["img"]`` when ``image`` is None).

        Returns the padded image when ``labels`` is empty; otherwise returns ``labels`` updated
        with the new image, the resized shape and the rescaled instances.
        """
        labels = {} if labels is None else labels
        img = image if image is not None else labels.get("img")
        shape = img.shape[:2]  # current shape [height, width]
        target = labels.pop("rect_shape", self.new_shape)
        if isinstance(target, int):
            target = (target, target)

        # Scale ratio (new / old); optionally forbid upscaling (for better val mAP)
        gain = min(target[0] / shape[0], target[1] / shape[1])
        if not self.scaleup:
            gain = min(gain, 1.0)

        # Size after resizing (width, height) and the padding left over on each axis
        ratio = gain, gain  # width, height ratios
        unpadded = int(round(shape[1] * gain)), int(round(shape[0] * gain))
        pad_w = target[1] - unpadded[0]
        pad_h = target[0] - unpadded[1]
        if self.auto:  # minimum rectangle
            pad_w, pad_h = np.mod(pad_w, self.stride), np.mod(pad_h, self.stride)
        elif self.scaleFill:  # stretch
            pad_w, pad_h = 0.0, 0.0
            unpadded = (target[1], target[0])
            ratio = (target[1] / shape[1], target[0] / shape[0])  # width, height ratios

        # Split the padding between the two sides
        pad_w /= 2
        pad_h /= 2
        if labels.get("ratio_pad"):
            labels["ratio_pad"] = (labels["ratio_pad"], (pad_w, pad_h))  # for evaluation

        if shape[::-1] != unpadded:  # resize
            img = cv2.resize(img, unpadded, interpolation=cv2.INTER_LINEAR)
        # The -0.1 / +0.1 offsets send an odd padding pixel to the bottom/right side
        top = int(round(pad_h - 0.1))
        bottom = int(round(pad_h + 0.1))
        left = int(round(pad_w - 0.1))
        right = int(round(pad_w + 0.1))
        img = cv2.copyMakeBorder(
            img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
        )  # add border

        if not len(labels):
            return img
        labels = self._update_labels(labels, ratio, pad_w, pad_h)
        labels["img"] = img
        labels["resized_shape"] = target
        return labels

    def _update_labels(self, labels, ratio, padw, padh):
        """Convert the instances to xyxy, denormalize them with the current labels["img"] size,
        then apply the scale ratio and padding."""
        instances = labels["instances"]
        instances.convert_bbox(format="xyxy")
        instances.denormalize(*labels["img"].shape[:2][::-1])
        instances.scale(*ratio)
        instances.add_padding(padw, padh)
        return labels
class Annotator:
    """
    YOLOv8 Annotator for train/val mosaics and jpgs and detect/hub inference annotations.

    Draws with PIL when ``pil=True`` or when ``example`` contains non-ASCII characters,
    otherwise draws with cv2 directly on the given array.
    """

    def __init__(
        self,
        im,
        line_width=None,
        font_size=None,
        font="Arial.ttf",
        pil=False,
        example="abc",
    ):
        """
        Args:
            im: image to annotate (numpy array, or a PIL image in PIL mode).
            line_width: box line width; derived from the image size when falsy.
            font_size: accepted for interface compatibility; not used in this class.
            font: accepted for interface compatibility; PIL's default font is always loaded.
            pil: force PIL drawing.
            example: sample label text used to decide whether non-ASCII rendering is needed.
        """
        assert (
            im.data.contiguous
        ), "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images."
        # non-latin labels, i.e. asian, arabic, cyrillic
        non_ascii = not is_ascii(example)
        self.pil = pil or non_ascii
        if self.pil:  # use PIL
            self.pil_9_2_0_check = check_version(
                pil_version, "9.2.0"
            )  # deprecation check
            self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
            self.draw = ImageDraw.Draw(self.im)
            self.font = ImageFont.load_default()
        else:  # use cv2
            self.im = im
        self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2)  # line width

    def box_label(
        self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255)
    ):
        """Add one xyxy box to the image, with an optional text label."""
        if isinstance(box, torch.Tensor):
            box = box.tolist()
        if self.pil or not is_ascii(label):
            self.draw.rectangle(box, width=self.lw, outline=color)  # box
            if label:
                if self.pil_9_2_0_check:
                    _, _, w, h = self.font.getbbox(label)  # text width, height (New)
                else:
                    w, h = self.font.getsize(
                        label
                    )  # text width, height (Old, deprecated in 9.2.0)
                outside = box[1] - h >= 0  # label fits outside box
                self.draw.rectangle(
                    (
                        box[0],
                        box[1] - h if outside else box[1],
                        box[0] + w + 1,
                        box[1] + 1 if outside else box[1] + h + 1,
                    ),
                    fill=color,
                )
                # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls')  # for PIL>8.0
                self.draw.text(
                    (box[0], box[1] - h if outside else box[1]),
                    label,
                    fill=txt_color,
                    font=self.font,
                )
        else:  # cv2
            p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
            cv2.rectangle(
                self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA
            )
            if label:
                tf = max(self.lw - 1, 1)  # font thickness
                # text width, height
                w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0]
                outside = p1[1] - h >= 3
                p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
                cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA)  # filled
                cv2.putText(
                    self.im,
                    label,
                    (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
                    0,
                    self.lw / 3,
                    txt_color,
                    thickness=tf,
                    lineType=cv2.LINE_AA,
                )

    def rectangle(self, xy, fill=None, outline=None, width=1):
        """Add a rectangle to the image (PIL-only)."""
        self.draw.rectangle(xy, fill, outline, width)

    def text(self, xy, text, txt_color=(255, 255, 255), anchor="top"):
        """Add text to the image (PIL-only).

        With ``anchor="bottom"`` the y coordinate is shifted up by the text height.
        ``xy`` may be any two-item sequence and is not modified.
        """
        if anchor == "bottom":  # start y from font bottom
            # Same version guard as box_label: use getbbox where available,
            # fall back to the older getsize otherwise.
            if self.pil_9_2_0_check:
                h = self.font.getbbox(text)[3]  # text height (New)
            else:
                h = self.font.getsize(text)[1]  # text height (Old, deprecated in 9.2.0)
            xy = (xy[0], xy[1] + 1 - h)  # new tuple: works for tuple input, caller's object untouched
        self.draw.text(xy, text, fill=txt_color, font=self.font)

    def fromarray(self, im):
        """Update self.im (and the PIL drawing context) from a numpy array or PIL image."""
        self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
        self.draw = ImageDraw.Draw(self.im)

    def result(self):
        """Return the annotated image as a numpy array."""
        return np.asarray(self.im)
def non_max_suppression(
    prediction,
    conf_thres=0.25,
    iou_thres=0.45,
    classes=None,
    agnostic=False,
    multi_label=False,
    labels=(),
    max_det=300,
    nm=0,  # number of masks
):
    """
    Run confidence filtering and non-maximum suppression on raw model predictions.

    Args:
        prediction: model output tensor, or a list/tuple whose first item is that tensor.
            The code indexes it as (batch, 4 + num_classes + nm, num_candidates).
        conf_thres (float): confidence threshold in [0, 1].
        iou_thres (float): IoU threshold handed to torchvision.ops.nms, in [0, 1].
        classes: optional list of class indices to keep.
        agnostic (bool): if True, NMS ignores the class when suppressing boxes.
        multi_label (bool): allow several labels per box (only effective when num_classes > 1).
        labels: optional per-image apriori labels with columns (cls, x, y, w, h), appended
            to the candidates before NMS.
        max_det (int): maximum number of detections kept per image.
        nm (int): number of mask coefficients per candidate.

    Returns:
        list: one tensor per image with rows (x1, y1, x2, y2, conf, cls, *mask). Images that
        were not processed (no candidates, or time limit hit) keep an empty (0, 6 + nm) tensor.
    """
    # Checks
    assert (
        0 <= conf_thres <= 1
    ), f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
    assert (
        0 <= iou_thres <= 1
    ), f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
    # YOLOv8 model in validation model, output = (inference_out, loss_out)
    if isinstance(prediction, (list, tuple)):
        prediction = prediction[0]  # select only inference output

    device = prediction.device
    mps = "mps" in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[1] - nm - 4  # number of classes
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    max_wh = 7680  # (pixels) maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 0.5 + 0.05 * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6 + nm), device=prediction.device) for _ in range(bs)]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x.transpose(0, -1)[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            # Same column count as x (4 box + nc classes + nm masks); a wider tensor
            # cannot be concatenated with x below.
            v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Detections matrix nx6 (xyxy, conf, cls)
        box, cls, mask = x.split((4, nc, nm), 1)
        # (center_x, center_y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(box)
        if multi_label:
            i, j = (cls > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        # sort by confidence and remove excess boxes
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]

        # Batched NMS: offsetting boxes by class keeps different classes from suppressing each other
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        # boxes (offset by class), scores
        boxes, scores = x[:, :4] + c, x[:, 4]
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections
        if merge and (1 < n < 3e3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(
                1, keepdim=True
            )  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if mps:
            output[xi] = output[xi].to(device)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded")
            break  # time limit exceeded

    return output
class Colors:
    """Ultralytics color palette https://ultralytics.com/ - a fixed, cyclic list of RGB tuples."""

    def __init__(self):
        """Build the palette from its hex codes."""
        # hex = matplotlib.colors.TABLEAU_COLORS.values()
        hex_codes = (
            "FF3838", "FF9D97", "FF701F", "FFB21D", "CFD231",
            "48F90A", "92CC17", "3DDB86", "1A9334", "00D4BB",
            "2C99A8", "00C2FF", "344593", "6473FF", "0018EC",
            "8438FF", "520085", "CB38FF", "FF95C8", "FF37C7",
        )
        self.palette = [self.hex2rgb("#" + code) for code in hex_codes]
        self.n = len(self.palette)

    def __call__(self, i, bgr=False):
        """Return palette entry ``int(i) % n`` as an RGB tuple, or as BGR when ``bgr`` is true."""
        red, green, blue = self.palette[int(i) % self.n]
        if bgr:
            return (blue, green, red)
        return (red, green, blue)

    @staticmethod
    def hex2rgb(h):  # rgb order (PIL)
        """Convert a '#RRGGBB' string to an (r, g, b) tuple of ints."""
        return tuple(int(h[start:start + 2], 16) for start in (1, 3, 5))


colors = Colors()  # create instance for 'from utils.plots import colors'
def threaded(func):
    """
    Decorator that runs ``func`` in a daemon thread. Usage: @threaded

    Calling the decorated function starts a new ``threading.Thread`` with the given
    arguments and returns that thread immediately; the return value of ``func`` is
    not available to the caller.
    """
    from functools import wraps  # local import: functools is not imported at file level

    @wraps(func)  # keep func's __name__ / __doc__ on the wrapper
    def wrapper(*args, **kwargs):
        thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
        thread.start()
        return thread

    return wrapper
def plot_images(
    images,
    batch_idx,
    cls,
    bboxes,
    masks=np.zeros(0, dtype=np.uint8),
    paths=None,
    fname="images.jpg",
    names=None,
):
    """
    Plot a grid (mosaic) of up to 16 images with their boxes and labels and save it to ``fname``.

    Args:
        images: batch of images, torch.Tensor or np.ndarray, unpacked as (batch, channels, h, w).
        batch_idx: for every box, the index of the image it belongs to.
        cls: class index of every box.
        bboxes: boxes in xywh; a 5th column, when present, is read as the confidence.
        masks: converted to numpy when given as a tensor; not drawn by this function.
        paths: optional image paths whose file names are written on the tiles.
        fname: output file name.
        names: optional mapping from class index to class name.

    The inputs are not modified.
    """
    # Plot image grid with labels
    if isinstance(images, torch.Tensor):
        images = images.cpu().float().numpy()
    if isinstance(cls, torch.Tensor):
        cls = cls.cpu().numpy()
    if isinstance(bboxes, torch.Tensor):
        bboxes = bboxes.cpu().numpy()
    if isinstance(masks, torch.Tensor):
        masks = masks.cpu().numpy().astype(int)
    if isinstance(batch_idx, torch.Tensor):
        batch_idx = batch_idx.cpu().numpy()

    max_size = 1920  # max image size
    max_subplots = 16  # max image subplots, i.e. 4x4
    bs, _, h, w = images.shape  # batch size, _, height, width
    bs = min(bs, max_subplots)  # limit plot images
    ns = np.ceil(bs**0.5)  # number of subplots (square)
    if np.max(images[0]) <= 1:
        # de-normalise (optional); out-of-place so the caller's array (which may share
        # memory with the input tensor) is left untouched
        images = images * 255

    # Build Image
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
    for i, im in enumerate(images[:bs]):
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        im = im.transpose(1, 2, 0)
        mosaic[y : y + h, x : x + w, :] = im

    # Resize (optional)
    scale = max_size / ns / max(h, w)
    if scale < 1:
        h = math.ceil(scale * h)
        w = math.ceil(scale * w)
        mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))

    # Annotate
    fs = int((h + w) * ns * 0.01)  # font size
    annotator = Annotator(
        mosaic, line_width=2, font_size=fs, pil=True, example=names
    )
    # Iterate exactly over the tiles placed in the mosaic (at most max_subplots)
    for i in range(bs):
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        annotator.rectangle(
            [x, y, x + w, y + h], None, (255, 255, 255), width=2
        )  # borders
        if paths:
            annotator.text(
                # filenames
                (x + 5, y + 5 + h),
                text=Path(paths[i]).name[:40],
                txt_color=(220, 220, 220),
            )
        if len(cls) > 0:
            idx = batch_idx == i
            boxes = xywh2xyxy(bboxes[idx, :4]).T
            classes = cls[idx].astype("int")
            labels = bboxes.shape[1] == 4  # labels if no conf column
            # check for confidence presence (label vs pred)
            conf = None if labels else bboxes[idx, 4]

            if boxes.shape[1]:
                if boxes.max() <= 1.01:  # if normalized with tolerance 0.01
                    boxes[[0, 2]] *= w  # scale to pixels
                    boxes[[1, 3]] *= h
                elif scale < 1:  # absolute coords need scale if image scales
                    boxes *= scale
            # shift from tile coordinates to mosaic coordinates
            boxes[[0, 2]] += x
            boxes[[1, 3]] += y
            for j, box in enumerate(boxes.T.tolist()):
                c = classes[j]
                color = colors(c)
                c = names[c] if names else c
                if labels or conf[j] > 0.25:  # 0.25 conf thresh
                    label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
                    annotator.box_label(box, label, color=color)
    annotator.im.save(fname)  # save
def output_to_target(output, max_det=300):
    """
    Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting.

    Args:
        output: iterable of per-image detection tensors whose first six columns are
            split as (box xyxy, conf, cls).
        max_det (int): maximum number of rows taken from each tensor.

    Returns:
        tuple of np.ndarray: (batch ids, class ids, remaining columns [x, y, w, h, conf]).
    """
    per_image = []
    for image_index, detections in enumerate(output):
        kept = detections[:max_det, :6].cpu()
        box, conf, cls = kept.split((4, 1, 1), 1)
        batch_column = torch.full((conf.shape[0], 1), image_index)
        row_block = torch.cat((batch_column, cls, xyxy2xywh(box), conf), 1)
        per_image.append(row_block)
    stacked = torch.cat(per_image, 0).numpy()
    batch_ids = stacked[:, 0]
    class_ids = stacked[:, 1]
    return batch_ids, class_ids, stacked[:, 2:]
def is_ascii(s=""):
    """Return True when ``str(s)`` is composed of ASCII characters only.

    (note str().isascii() introduced in python 3.7)
    """
    text = str(s)  # convert list, tuple, None, etc. to str
    # Non-ASCII characters are dropped by the lossy decode, so the lengths
    # only match when there was nothing to drop.
    ascii_part = text.encode().decode("ascii", "ignore")
    return len(ascii_part) == len(text)
def xyxy2xywh(x):
    """
    Convert boxes from corner format (x1, y1, x2, y2) to center format (x, y, width, height).

    Args:
        x (np.ndarray) or (torch.Tensor): boxes in (x1, y1, x2, y2) format along the last axis.

    Returns:
        (np.ndarray) or (torch.Tensor): a copy of ``x`` whose first four entries along the last
        axis hold (x center, y center, width, height). The input is left unchanged.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    left, top = x[..., 0], x[..., 1]
    right, bottom = x[..., 2], x[..., 3]
    out[..., 0] = (left + right) / 2  # x center
    out[..., 1] = (top + bottom) / 2  # y center
    out[..., 2] = right - left  # width
    out[..., 3] = bottom - top  # height
    return out
def xywh2xyxy(x):
    """
    Convert boxes from center format (x, y, width, height) to corner format (x1, y1, x2, y2),
    where xy1 is the top-left and xy2 the bottom-right corner.

    Args:
        x (np.ndarray) or (torch.Tensor): boxes in (x, y, w, h) format along the last axis;
            any number of leading dimensions is accepted (a single box, nx4, bxnx4, ...).

    Returns:
        (np.ndarray) or (torch.Tensor): a copy of ``x`` with the first four entries along the
        last axis replaced by (x1, y1, x2, y2). The input is left unchanged.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    # Index the last axis with `...` (as xyxy2xywh does) so batched or single-box
    # inputs are handled the same way as the nx4 case.
    half_w = x[..., 2] / 2
    half_h = x[..., 3] / 2
    y[..., 0] = x[..., 0] - half_w  # top left x
    y[..., 1] = x[..., 1] - half_h  # top left y
    y[..., 2] = x[..., 0] + half_w  # bottom right x
    y[..., 3] = x[..., 1] + half_h  # bottom right y
    return y
def check_det_dataset(dataset, autodownload=True):
    """
    Load (if needed) and normalise a detection dataset description.

    Args:
        dataset: a dataset dict, or the path (str/Path) of a YAML file describing one.
        autodownload (bool): when the 'val' paths are missing and the dataset provides a
            'download' entry, run that entry instead of raising.

    Returns:
        dict: the dataset dictionary with 'names' as a dict, 'nc' set, and the
        'train'/'val'/'test' entries resolved to absolute path strings.

    Raises:
        FileNotFoundError: if 'val' paths are missing and no download is possible/allowed.
    """
    data = dataset

    # Download (optional)
    extract_dir = ''

    # Read yaml (optional)
    if isinstance(data, (str, Path)):
        data = yaml_load(data, append_filename=True)  # dictionary

    # Checks
    class_names = data['names']
    if isinstance(class_names, (list, tuple)):  # old array format
        data['names'] = dict(enumerate(class_names))  # convert to dict
    data['nc'] = len(data['names'])

    # Resolve paths
    path = Path(extract_dir or data.get('path') or Path(data.get('yaml_file', '')).parent)  # dataset root
    DATASETS_DIR = os.path.abspath('.')
    if not path.is_absolute():
        path = (DATASETS_DIR / path).resolve()
        data['path'] = path  # download scripts
    for k in 'train', 'val', 'test':
        entry = data.get(k)
        if not entry:
            continue
        # prepend path
        if isinstance(entry, str):
            resolved = (path / entry).resolve()
            if not resolved.exists() and entry.startswith('../'):
                resolved = (path / entry[3:]).resolve()
            data[k] = str(resolved)
        else:
            data[k] = [str((path / item).resolve()) for item in entry]

    # Parse yaml
    val = data.get('val')
    s = data.get('download')
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            msg = f"\nDataset '{dataset}' not found ⚠️, missing paths %s" % [str(x) for x in val if not x.exists()]
            if not (s and autodownload):
                raise FileNotFoundError(msg)
            LOGGER.warning(msg)
            t = time.time()
            # NOTE(review): the 'download' entry comes straight from the dataset YAML and is
            # executed as a shell command or as Python code - only use trusted dataset files.
            if s.startswith('bash '):  # bash script
                LOGGER.info(f'Running {s} ...')
                r = os.system(s)
            else:  # python script
                r = exec(s, {'yaml': data})  # return None
            dt = f'({round(time.time() - t, 1)}s)'
            s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
            LOGGER.info(f"Dataset download {s}\n")
    return data  # dictionary
def yaml_load(file='data.yaml', append_filename=False):
    """
    Load YAML data from a file.

    Args:
        file (str, optional): File name. Default is 'data.yaml'.
        append_filename (bool): Add the YAML filename to the YAML dictionary. Default is False.

    Returns:
        dict: YAML data, plus a 'yaml_file' entry holding ``str(file)`` when ``append_filename``
        is True. An empty YAML file yields an empty dict.
    """
    with open(file, errors='ignore', encoding='utf-8') as f:
        s = f.read()  # string

    if not s.isprintable():  # remove special characters
        s = re.sub(r'[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+', '', s)

    # safe_load returns None for an empty document; normalise that to a dict so that the
    # merge below cannot fail. The text is parsed once for both return paths.
    data = yaml.safe_load(s) or {}
    if append_filename:
        return {**data, 'yaml_file': str(file)}
    return data
class IterableSimpleNamespace(SimpleNamespace):
    """
    SimpleNamespace that can be iterated as (name, value) pairs, so instances work with
    dict() and in for loops, and that offers a dict-like get().
    """

    def __iter__(self):
        """Iterate over the (attribute name, value) pairs of the namespace."""
        attributes = vars(self)
        return iter(attributes.items())

    def __str__(self):
        """Return one 'name=value' line per attribute."""
        lines = [f"{k}={v}" for k, v in vars(self).items()]
        return '\n'.join(lines)

    def get(self, key, default=None):
        """Return attribute ``key`` if present, otherwise ``default``."""
        return getattr(self, key, default)
def colorstr(*input):
    """
    Color a string with ANSI escape codes https://en.wikipedia.org/wiki/ANSI_escape_code,
    i.e. colorstr('blue', 'hello world').

    The last positional argument is the text; the preceding ones are style names. When only
    the text is given, 'blue' and 'bold' are applied.
    """
    if len(input) > 1:
        *args, string = input  # color arguments, string
    else:
        args, string = ["blue", "bold"], input[0]

    ansi_codes = {
        # basic colors
        "black": "\033[30m",
        "red": "\033[31m",
        "green": "\033[32m",
        "yellow": "\033[33m",
        "blue": "\033[34m",
        "magenta": "\033[35m",
        "cyan": "\033[36m",
        "white": "\033[37m",
        # bright colors
        "bright_black": "\033[90m",
        "bright_red": "\033[91m",
        "bright_green": "\033[92m",
        "bright_yellow": "\033[93m",
        "bright_blue": "\033[94m",
        "bright_magenta": "\033[95m",
        "bright_cyan": "\033[96m",
        "bright_white": "\033[97m",
        # misc
        "end": "\033[0m",
        "bold": "\033[1m",
        "underline": "\033[4m",
    }
    prefix = "".join(ansi_codes[style] for style in args)
    return prefix + f"{string}" + ansi_codes["end"]
def seed_worker(worker_id):
    """Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader

    Seeds numpy and the ``random`` module from torch's initial seed, reduced modulo 2**32.
    ``worker_id`` is not used.
    """
    seed = torch.initial_seed() % (1 << 32)
    for seed_fn in (np.random.seed, random.seed):
        seed_fn(seed)
def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, rank=-1, mode="train"):
    """
    Build a YOLODataset and the DataLoader that serves it.

    Args:
        cfg: configuration object; the attributes rect, imgsz, cache, single_cls, task, workers,
            image_weights and close_mosaic are read from it.
        batch (int): requested batch size (capped at the dataset length).
        img_path: image path(s) handed to the dataset.
        stride (int): stride handed to the dataset.
        rect (bool): request rectangular batches in addition to ``cfg.rect``.
        names: class names handed to the dataset.
        rank (int): -1 for single-process use; any other value selects a DistributedSampler.
        mode (str): "train" or "val".

    Returns:
        tuple: (DataLoader, dataset).
    """
    assert mode in ["train", "val"]
    use_rect = cfg.rect or rect  # rectangular batches requested by either source
    shuffle = mode == "train"
    if use_rect and shuffle:
        LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
        shuffle = False
    dataset = YOLODataset(
        img_path=img_path,
        imgsz=cfg.imgsz,
        batch_size=batch,
        augment=mode == "train",  # augmentation
        hyp=cfg,
        rect=use_rect,  # rectangular batches
        cache=cfg.cache or None,
        single_cls=cfg.single_cls or False,
        stride=int(stride),
        pad=0.0 if mode == "train" else 0.5,
        prefix=colorstr(f"{mode}: "),
        use_segments=cfg.task == "segment",
        use_keypoints=cfg.task == "keypoint",
        names=names)

    batch = min(batch, len(dataset))
    nd = torch.cuda.device_count()  # number of CUDA devices
    workers = cfg.workers if mode == "train" else cfg.workers * 2
    cpu_total = os.cpu_count() or 1  # os.cpu_count() may return None
    nw = min([cpu_total // max(nd, 1), batch if batch > 1 else 0, workers])  # number of workers

    # Always bind sampler and loader: leaving them undefined for rank != -1, or when neither
    # cfg.image_weights nor cfg.close_mosaic is set, made the return statement fail.
    sampler = None if rank == -1 else torch.utils.data.DistributedSampler(dataset, shuffle=shuffle)
    loader = DataLoader

    generator = torch.Generator()
    generator.manual_seed(6148914691236517205)
    return loader(dataset=dataset,
                  batch_size=batch,
                  shuffle=shuffle and sampler is None,
                  num_workers=nw,
                  sampler=sampler,
                  pin_memory=PIN_MEMORY,
                  collate_fn=getattr(dataset, "collate_fn", None),
                  worker_init_fn=seed_worker,
                  generator=generator), dataset
class BaseDataset(Dataset):
"""Base Dataset.
Args:
img_path (str): image path.
pipeline (dict): a dict of image transforms.
label_path (str): label path, this can also be an ann_file or other custom label path.
"""
def __init__(
self,
img_path,
imgsz=640,
cache=False,
augment=True,
hyp=None,
prefix="",
rect=False,
batch_size=None,
stride=32,
pad=0.5,
single_cls=False,
):
super().__init__()
self.img_path = img_path
self.imgsz = imgsz
self.augment = augment
self.single_cls = single_cls
self.prefix = prefix
self.im_files = self.get_img_files(self.img_path)
self.labels = self.get_labels()
self.ni = len(self.labels)
# rect stuff
self.rect = rect
self.batch_size = batch_size
self.stride = stride
self.pad = pad
if self.rect:
assert self.batch_size is not None
self.set_rectangle()
# cache stuff
self.ims = [None] * self.ni
self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
if cache:
self.cache_images(cache)
# transforms
self.transforms = self.build_transforms(hyp=hyp)
def get_img_files(self, img_path):
"""Read image files."""
try:
f = [] # image files
for p in img_path if isinstance(img_path, list) else [img_path]:
p = Path(p) # os-agnostic
if p.is_dir(): # dir
f += glob.glob(str(p / "**" / "*.*"), recursive=True)
# f = list(p.rglob('*.*')) # pathlib
elif p.is_file(): # file
with open(p) as t:
t = t.read().strip().splitlines()
parent = str(p.parent) + os.sep
f += [x.replace("./", parent) if x.startswith("./") else x for x in t] # local to global path
# f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
else:
raise FileNotFoundError(f"{self.prefix}{p} does not exist")
im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
assert im_files, f"{self.prefix}No images found"
except Exception as e:
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n") from e
return im_files
def load_image(self, i):
# Loads 1 image from dataset index 'i', returns (im, resized hw)
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
if im is None: # not cached in RAM
if fn.exists(): # load npy
im = np.load(fn)
else: # read image
im = cv2.imread(f) # BGR
if im is None:
raise FileNotFoundError(f"Image Not Found {f}")
h0, w0 = im.shape[:2] # orig hw
r = self.imgsz / max(h0, w0) # ratio
if r != 1: # if sizes are not equal
interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
im = cv2.resize(im, (math.ceil(w0 * r), math.ceil(h0 * r)), interpolation=interp)
return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
def cache_images(self, cache):
# cache images to memory or disk
gb = 0 # Gigabytes of cached images
self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
fcn = self.cache_images_to_disk if cache == "disk" else self.load_image
with ThreadPool(NUM_THREADS) as pool:
results = pool.imap(fcn, range(self.ni))
pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT)
for i, x in pbar:
if cache == "disk":
gb += self.npy_files[i].stat().st_size
else: # 'ram'
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
gb += self.ims[i].nbytes
pbar.desc = f"{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})"
pbar.close()
def cache_images_to_disk(self, i):
# Saves an image as an *.npy file for faster loading
f = self.npy_files[i]
if not f.exists():
np.save(f.as_posix(), cv2.imread(self.im_files[i]))
def set_rectangle(self):
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
nb = bi[-1] + 1 # number of batches
s = np.array([x.pop("shape") for x in self.labels]) # hw
ar = s[:, 0] / s[:, 1] # aspect ratio
irect = ar.argsort()
self.im_files = [self.im_files[i] for i in irect]
self.labels = [self.labels[i] for i in irect]
ar = ar[irect]
# Set training image shapes
shapes = [[1, 1]] * nb
for i in range(nb):
ari = ar[bi == i]
mini, maxi = ari.min(), ari.max()
if maxi < 1:
shapes[i] = [maxi, 1]
elif mini > 1:
shapes[i] = [1, 1 / mini]
self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
self.batch = bi # batch index of image
def __getitem__(self, index):
return self.transforms(self.get_label_info(index))
def get_label_info(self, index):
label = self.labels[index].copy()
label.pop("shape", None) # shape is for rect, remove it
label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
label["ratio_pad"] = (
label["resized_shape"][0] / label["ori_shape"][0],
label["resized_shape"][1] / label["ori_shape"][1],
) # for evaluation
if self.rect:
label["rect_shape"] = self.batch_shapes[self.batch[index]]
label = self.update_labels_info(label)
return label
def __len__(self):
return len(self.labels)
def update_labels_info(self, label):
    """Hook for customising the label format; this base version returns ``label`` unchanged."""
    return label
def build_transforms(self, hyp=None):
    """Hook where users define their augmentations; this version always raises.

    An override would typically look like:
        if self.augment:
            # training transforms
            return Compose([])
        else:
            # val transforms
            return Compose([])

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
def get_labels(self):
    """Hook where users load labels in their own format; this version always raises.

    An override should return a list in which each element looks like:
        dict(
            im_file=im_file,
            shape=shape,  # format: (height, width)
            cls=cls,
            bboxes=bboxes,  # xywh
            segments=segments,  # xy
            keypoints=keypoints,  # xy
            normalized=True,  # or False
            bbox_format="xyxy",  # or xywh, ltwh
        )

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
def img2label_paths(img_paths):
    """Derive label file paths from image file paths.

    The last ``/images/`` path component is replaced by ``/labels/`` and the file extension
    (everything after the last ``"."``) by ``.txt``.

    Args:
        img_paths (Iterable[str]): Image paths.

    Returns:
        list[str]: One label path per image path, in the same order.
    """
    images_part = f"{os.sep}images{os.sep}"
    labels_part = f"{os.sep}labels{os.sep}"
    label_paths = []
    for img_path in img_paths:
        swapped = labels_part.join(img_path.rsplit(images_part, 1))
        without_ext = swapped.rsplit(".", 1)[0]
        label_paths.append(without_ext + ".txt")
    return label_paths
def get_hash(paths):
    """Return a single MD5 hex digest for a list of paths (files or dirs).

    The digest covers the summed size of the paths that exist, followed by the
    concatenation of all path strings.

    Args:
        paths (list[str]): Paths to hash.

    Returns:
        str: Hex digest.
    """
    total_size = 0
    for p in paths:
        if os.path.exists(p):
            total_size += os.path.getsize(p)
    digest = hashlib.md5(str(total_size).encode())  # hash sizes
    digest.update("".join(paths).encode())  # hash paths
    return digest.hexdigest()
class Compose:
    """Chain of callables applied one after another to a piece of data."""

    def __init__(self, transforms):
        # Stored by reference: append() mutates the list that was passed in.
        self.transforms = transforms

    def __call__(self, data):
        """Feed ``data`` through every transform in order and return the final result."""
        result = data
        for transform in self.transforms:
            result = transform(result)
        return result

    def append(self, transform):
        """Add ``transform`` to the end of the chain."""
        self.transforms.append(transform)

    def tolist(self):
        """Return the underlying list of transforms (not a copy)."""
        return self.transforms

    def __repr__(self):
        lines = "".join("\n" + f" {t}" for t in self.transforms)
        return f"{self.__class__.__name__}(" + lines + "\n)"
class Format:
def __init__(self,
bbox_format="xywh",
normalize=True,
return_mask=False,
return_keypoint=False,
mask_ratio=4,
mask_overlap=True,
batch_idx=True):
self.bbox_format = bbox_format
self.normalize = normalize
self.return_mask = return_mask # set False when training detection only
self.return_keypoint = return_keypoint
self.mask_ratio = mask_ratio
self.mask_overlap = mask_overlap
self.batch_idx = batch_idx # keep the batch indexes
def __call__(self, labels):
img = labels.pop("img")
h, w = img.shape[:2]
cls = labels.pop("cls")
instances = labels.pop("instances")
instances.convert_bbox(format=self.bbox_format)
instances.denormalize(w, h)
nl = len(instances)
if self.normalize:
instances.normalize(w, h)
labels["img"] = self._format_img(img)
labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
if self.return_keypoint:
labels["keypoints"] = torch.from_numpy(instances.keypoints) if nl else torch.zeros((nl, 17, 2))
# then we can use collate_fn
if self.batch_idx:
labels["batch_idx"] = torch.zeros(nl)
return labels
def _format_img(self, img):
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1])
img = torch.from_numpy(img)
return img
class Bboxes:
    """An (N, 4) numpy array of boxes tagged with its coordinate format.

    Now only numpy is supported. Supported formats are the entries of ``_formats``:
    ``"xyxy"`` (corners), ``"xywh"`` (centre + size) and ``"ltwh"`` (top-left + size).
    """

    def __init__(self, bboxes, format="xyxy") -> None:
        """
        Args:
            bboxes (ndarray): Boxes with shape (4,) or (N, 4); a 1-D input is promoted to (1, 4).
            format (str): One of ``_formats``.
        """
        assert format in _formats
        bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
        assert bboxes.ndim == 2
        assert bboxes.shape[1] == 4
        self.bboxes = bboxes
        self.format = format

    @staticmethod
    def _xyxy2ltwh(x):
        """(x1, y1, x2, y2) -> (left, top, w, h)."""
        y = np.copy(x)
        y[..., 2] = x[..., 2] - x[..., 0]
        y[..., 3] = x[..., 3] - x[..., 1]
        return y

    @staticmethod
    def _ltwh2xyxy(x):
        """(left, top, w, h) -> (x1, y1, x2, y2)."""
        y = np.copy(x)
        y[..., 2] = x[..., 0] + x[..., 2]
        y[..., 3] = x[..., 1] + x[..., 3]
        return y

    @staticmethod
    def _xywh2ltwh(x):
        """(cx, cy, w, h) -> (left, top, w, h)."""
        y = np.copy(x)
        y[..., 0] = x[..., 0] - x[..., 2] / 2
        y[..., 1] = x[..., 1] - x[..., 3] / 2
        return y

    @staticmethod
    def _ltwh2xywh(x):
        """(left, top, w, h) -> (cx, cy, w, h)."""
        y = np.copy(x)
        y[..., 0] = x[..., 0] + x[..., 2] / 2
        y[..., 1] = x[..., 1] + x[..., 3] / 2
        return y

    def convert(self, format):
        """Convert the stored boxes to ``format`` in place (no-op when already in that format).

        The xyxy <-> xywh conversions use the module's ``xyxy2xywh`` / ``xywh2xyxy`` helpers;
        conversions involving ``ltwh`` (previously an UnboundLocalError) are done locally.
        """
        assert format in _formats
        if self.format == format:
            return
        if self.format == "xyxy":
            bboxes = xyxy2xywh(self.bboxes) if format == "xywh" else self._xyxy2ltwh(self.bboxes)
        elif self.format == "xywh":
            bboxes = xywh2xyxy(self.bboxes) if format == "xyxy" else self._xywh2ltwh(self.bboxes)
        else:  # self.format == "ltwh"
            bboxes = self._ltwh2xyxy(self.bboxes) if format == "xyxy" else self._ltwh2xywh(self.bboxes)
        self.bboxes = bboxes
        self.format = format

    def areas(self):
        """Return the area of every box. Note: converts the stored boxes to xyxy as a side effect."""
        self.convert("xyxy")
        return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])

    def mul(self, scale):
        """
        Multiply the four coordinates in place.

        Args:
            scale (tuple | List | int | float): the scale for four coords; a single number
                is applied to all four.
        """
        if isinstance(scale, (int, float, np.number)):
            scale = (scale,) * 4
        assert isinstance(scale, (tuple, list))
        assert len(scale) == 4
        self.bboxes[:, 0] *= scale[0]
        self.bboxes[:, 1] *= scale[1]
        self.bboxes[:, 2] *= scale[2]
        self.bboxes[:, 3] *= scale[3]

    def add(self, offset):
        """
        Add to the four coordinates in place.

        Args:
            offset (tuple | List | int | float): the offset for four coords; a single number
                is applied to all four.
        """
        if isinstance(offset, (int, float, np.number)):
            offset = (offset,) * 4
        assert isinstance(offset, (tuple, list))
        assert len(offset) == 4
        self.bboxes[:, 0] += offset[0]
        self.bboxes[:, 1] += offset[1]
        self.bboxes[:, 2] += offset[2]
        self.bboxes[:, 3] += offset[3]

    def __len__(self):
        return len(self.bboxes)

    @classmethod
    def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes":
        """
        Concatenates a list of Bboxes into a single Bboxes.

        Arguments:
            boxes_list (list[Bboxes]): boxes to join; they are not converted, so they are
                expected to share one format (the first element's format is used for the result).
            axis (int): axis passed to ``np.concatenate``.
        Returns:
            Bboxes: the concatenated boxes; an empty (0, 4) Bboxes for an empty list, and the
            single element itself (not a copy) for a one-element list.
        """
        assert isinstance(boxes_list, (list, tuple))
        if not boxes_list:
            # (0, 4) keeps the shape invariant checked in __init__; np.empty(0) did not
            return cls(np.empty((0, 4)))
        assert all(isinstance(box, Bboxes) for box in boxes_list)
        if len(boxes_list) == 1:
            return boxes_list[0]
        merged = np.concatenate([b.bboxes for b in boxes_list], axis=axis)
        return cls(merged, format=boxes_list[0].format)

    def __getitem__(self, index) -> "Bboxes":
        """
        Args:
            index: int, slice, or a BoolArray
        Returns:
            Bboxes: Create a new :class:`Bboxes` by indexing; the format is preserved.
        """
        if isinstance(index, (int, np.integer)):
            # reshape, not the torch-style .view(1, -1), which is a TypeError on ndarrays
            return Bboxes(self.bboxes[index].reshape(1, -1), format=self.format)
        b = self.bboxes[index]
        assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
        return Bboxes(b, format=self.format)
def resample_segments(segments, n=1000):
    """
    Resample every (m, 2) segment to ``n`` points by linear interpolation along a closed contour.

    Each segment is closed by appending its first point, then x and y are interpolated at
    ``n`` evenly spaced positions along the point index. The input list is modified in place.

    Args:
        segments (list): a list of (m, 2) arrays, where m is the number of points in the segment.
        n (int): number of points to resample the segment to. Defaults to 1000
    Returns:
        segments (list): the same list object, now holding the resampled (n, 2) segments.
    """
    for idx, seg in enumerate(segments):
        closed = np.concatenate((seg, seg[0:1, :]), axis=0)
        positions = np.arange(len(closed))
        samples = np.linspace(0, len(closed) - 1, n)
        xs = np.interp(samples, positions, closed[:, 0])
        ys = np.interp(samples, positions, closed[:, 1])
        segments[idx] = np.vstack((xs, ys)).T  # segment xy
    return segments
class Instances:
    """Boxes, segments and (optional) keypoints of one image, transformed together."""

    def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
        """
        Args:
            bboxes (ndarray): bboxes with shape [N, 4].
            segments (list | ndarray): segments.
            keypoints (ndarray): keypoints with shape [N, 17, 2].
            bbox_format (str): format of ``bboxes``, passed to :class:`Bboxes`.
            normalized (bool): whether the coordinates are currently normalized.
        """
        if segments is None:
            segments = []
        self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
        self.keypoints = keypoints
        self.normalized = normalized
        if len(segments) > 0:
            # list[np.array(1000, 2)] * num_samples
            segments = resample_segments(segments)
            # (N, 1000, 2)
            segments = np.stack(segments, axis=0)
        else:
            segments = np.zeros((0, 1000, 2), dtype=np.float32)
        self.segments = segments

    def convert_bbox(self, format):
        """Convert the boxes to ``format`` in place."""
        self._bboxes.convert(format=format)

    def bbox_areas(self):
        """Return the area of every box (the result used to be computed and then dropped)."""
        return self._bboxes.areas()

    def scale(self, scale_w, scale_h, bbox_only=False):
        """this might be similar with denormalize func but without normalized sign"""
        self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
        if bbox_only:
            return
        self.segments[..., 0] *= scale_w
        self.segments[..., 1] *= scale_h
        if self.keypoints is not None:
            self.keypoints[..., 0] *= scale_w
            self.keypoints[..., 1] *= scale_h

    def denormalize(self, w, h):
        """Scale normalized coordinates by (w, h) in place; no-op if already denormalized."""
        if not self.normalized:
            return
        self._bboxes.mul(scale=(w, h, w, h))
        self.segments[..., 0] *= w
        self.segments[..., 1] *= h
        if self.keypoints is not None:
            self.keypoints[..., 0] *= w
            self.keypoints[..., 1] *= h
        self.normalized = False

    def normalize(self, w, h):
        """Divide coordinates by (w, h) in place; no-op if already normalized."""
        if self.normalized:
            return
        self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h))
        self.segments[..., 0] /= w
        self.segments[..., 1] /= h
        if self.keypoints is not None:
            self.keypoints[..., 0] /= w
            self.keypoints[..., 1] /= h
        self.normalized = True

    def add_padding(self, padw, padh):
        """Shift all coordinates by (padw, padh); requires absolute coordinates."""
        # handle rect and mosaic situation
        assert not self.normalized, "you should add padding with absolute coordinates."
        self._bboxes.add(offset=(padw, padh, padw, padh))
        self.segments[..., 0] += padw
        self.segments[..., 1] += padh
        if self.keypoints is not None:
            self.keypoints[..., 0] += padw
            self.keypoints[..., 1] += padh

    def __getitem__(self, index) -> "Instances":
        """
        Args:
            index: int, slice, or a BoolArray
        Returns:
            Instances: Create a new :class:`Instances` by indexing.
        """
        # NOTE(review): building a new Instances runs resample_segments again on the selected
        # segments whenever any are present - confirm that repeated resampling is intended.
        segments = self.segments[index] if len(self.segments) else self.segments
        keypoints = self.keypoints[index] if self.keypoints is not None else None
        bboxes = self.bboxes[index]
        bbox_format = self._bboxes.format
        return Instances(
            bboxes=bboxes,
            segments=segments,
            keypoints=keypoints,
            bbox_format=bbox_format,
            normalized=self.normalized,
        )

    def flipud(self, h):
        """Flip all coordinates vertically within an image of height ``h`` (in place)."""
        if self._bboxes.format == "xyxy":
            y1 = self.bboxes[:, 1].copy()
            y2 = self.bboxes[:, 3].copy()
            self.bboxes[:, 1] = h - y2
            self.bboxes[:, 3] = h - y1
        else:
            self.bboxes[:, 1] = h - self.bboxes[:, 1]
        self.segments[..., 1] = h - self.segments[..., 1]
        if self.keypoints is not None:
            self.keypoints[..., 1] = h - self.keypoints[..., 1]

    def fliplr(self, w):
        """Flip all coordinates horizontally within an image of width ``w`` (in place)."""
        if self._bboxes.format == "xyxy":
            x1 = self.bboxes[:, 0].copy()
            x2 = self.bboxes[:, 2].copy()
            self.bboxes[:, 0] = w - x2
            self.bboxes[:, 2] = w - x1
        else:
            self.bboxes[:, 0] = w - self.bboxes[:, 0]
        self.segments[..., 0] = w - self.segments[..., 0]
        if self.keypoints is not None:
            self.keypoints[..., 0] = w - self.keypoints[..., 0]

    def clip(self, w, h):
        """Clip boxes, segments and keypoints to [0, w] x [0, h]; the box format is restored afterwards."""
        ori_format = self._bboxes.format
        self.convert_bbox(format="xyxy")
        self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
        self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
        if ori_format != "xyxy":
            self.convert_bbox(format=ori_format)
        self.segments[..., 0] = self.segments[..., 0].clip(0, w)
        self.segments[..., 1] = self.segments[..., 1].clip(0, h)
        if self.keypoints is not None:
            self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)
            self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)

    def update(self, bboxes, segments=None, keypoints=None):
        """Replace the boxes (keeping the current format) and, if given, segments and keypoints."""
        new_bboxes = Bboxes(bboxes, format=self._bboxes.format)
        self._bboxes = new_bboxes
        if segments is not None:
            self.segments = segments
        if keypoints is not None:
            self.keypoints = keypoints

    def __len__(self):
        return len(self.bboxes)

    @classmethod
    def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances":
        """
        Concatenates a list of Instances into a single Instances.

        Arguments:
            instances_list (list[Instances]): instances to join; box format, normalization
                flag and keypoint usage are taken from the first element.
            axis (int): axis passed to ``np.concatenate``.
        Returns:
            Instances: the concatenated instances; an empty Instances for an empty list, and
            the single element itself (not a copy) for a one-element list.
        """
        assert isinstance(instances_list, (list, tuple))
        if not instances_list:
            # (0, 4) satisfies the Bboxes shape check; np.empty(0) tripped its assertion
            return cls(np.empty((0, 4)))
        assert all(isinstance(instance, Instances) for instance in instances_list)
        if len(instances_list) == 1:
            return instances_list[0]
        use_keypoint = instances_list[0].keypoints is not None
        bbox_format = instances_list[0]._bboxes.format
        normalized = instances_list[0].normalized
        cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
        cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis)
        cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None
        return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized)

    @property
    def bboxes(self):
        """The underlying (N, 4) box array."""
        return self._bboxes.bboxes
def is_dir_writeable(dir_path: Union[str, Path]) -> bool:
    """
    Report whether a temporary file can be created inside a directory.

    Args:
        dir_path (str) or (Path): The path to the directory.

    Returns:
        bool: True if creating (and closing) a temporary file in ``dir_path`` succeeds,
        False if that raises ``OSError``.
    """
    try:
        with tempfile.TemporaryFile(dir=dir_path):
            writeable = True
    except OSError:
        writeable = False
    return writeable
class YOLODataset(BaseDataset):
    """YOLO Dataset.

    Args:
        img_path (str): image path.
        prefix (str): prefix.
    """
    cache_version = '1.0.1'  # dataset labels *.cache version, >= 1.0.0 for YOLOv8
    rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]

    def __init__(self,
                 img_path,
                 imgsz=640,
                 cache=False,
                 augment=True,
                 hyp=None,
                 prefix="",
                 rect=False,
                 batch_size=None,
                 stride=32,
                 pad=0.0,
                 single_cls=False,
                 use_segments=False,
                 use_keypoints=False,
                 names=None):
        # These attributes are set before super().__init__() is called.
        self.use_segments = use_segments
        self.use_keypoints = use_keypoints
        self.names = names
        assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
        super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls)

    def cache_labels(self, path=Path("./labels.cache")):
        """Verify every image/label pair, collect the labels and write them to a cache file.

        Args:
            path (Path): Location of the cache file; an existing file is removed first.

        Returns:
            dict: Cache contents with keys ``labels``, ``hash``, ``results``, ``msgs`` and ``version``.
        """
        # Cache dataset labels, check images and read shapes
        if path.exists():
            path.unlink()  # remove *.cache file if exists
        x = {"labels": []}
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
        total = len(self.im_files)
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(func=verify_image_label,
                                iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
                                             repeat(self.use_keypoints), repeat(len(self.names))))
            pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
            for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    x["labels"].append(
                        dict(
                            im_file=im_file,
                            shape=shape,
                            cls=lb[:, 0:1],  # n, 1
                            bboxes=lb[:, 1:],  # n, 4
                            segments=segments,
                            keypoints=keypoint,
                            normalized=True,
                            bbox_format="xywh"))
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            pbar.close()
        if msgs:
            LOGGER.info("\n".join(msgs))
        x["hash"] = get_hash(self.label_files + self.im_files)
        x["results"] = nf, nm, ne, nc, len(self.im_files)
        x["msgs"] = msgs  # warnings
        x["version"] = self.cache_version  # cache version
        self.im_files = [lb["im_file"] for lb in x["labels"]]  # update im_files
        if is_dir_writeable(path.parent):
            np.save(str(path), x)  # save cache for next time
            path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix
            LOGGER.info(f"{self.prefix}New cache created: {path}")
        else:
            LOGGER.warning(f"{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable")  # not writeable
        return x

    def get_labels(self):
        """Load labels from the ``*.cache`` file, rebuilding it when missing, outdated or malformed.

        Returns:
            list[dict]: One label dict per image (see ``cache_labels`` for the keys).
        """
        self.label_files = img2label_paths(self.im_files)
        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
        try:
            # NOTE: allow_pickle=True unpickles arbitrary objects; only load cache files you trust.
            cache = np.load(str(cache_path), allow_pickle=True).item()  # load dict
            # Explicit comparison instead of assert, so the check survives `python -O`
            exists = (cache["version"] == self.cache_version  # matches current version
                      and cache["hash"] == get_hash(self.label_files + self.im_files))  # identical hash
        except (FileNotFoundError, AssertionError, AttributeError, KeyError):
            exists = False
        if not exists:
            cache = self.cache_labels(cache_path)  # run cache ops

        # Display cache
        nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
        if exists:
            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display cache results
            if cache["msgs"]:
                LOGGER.info("\n".join(cache["msgs"]))  # display warnings

        # Read cache
        for k in ("hash", "version", "msgs"):  # remove items
            cache.pop(k)
        labels = cache["labels"]

        # Check if the dataset is all boxes or all segments
        len_boxes = sum(len(lb["bboxes"]) for lb in labels)
        len_segments = sum(len(lb["segments"]) for lb in labels)
        if len_segments and len_boxes != len_segments:
            LOGGER.warning(
                f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
                f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
                "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.")
            for lb in labels:
                lb["segments"] = []
        return labels

    def build_transforms(self, hyp=None):
        """Build the transform pipeline: LetterBox to ``imgsz`` followed by Format.

        Args:
            hyp: Object with ``mask_ratio`` and ``overlap_mask`` attributes. When None, Format's
                own defaults are used for those two settings instead of failing on ``None``.
        """
        transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
        mask_kwargs = {} if hyp is None else dict(mask_ratio=hyp.mask_ratio, mask_overlap=hyp.overlap_mask)
        transforms.append(
            Format(bbox_format="xywh",
                   normalize=True,
                   return_mask=self.use_segments,
                   return_keypoint=self.use_keypoints,
                   batch_idx=True,
                   **mask_kwargs))
        return transforms

    def close_mosaic(self, hyp):
        """Zero the mosaic, copy_paste and mixup settings on ``hyp`` and rebuild the transforms."""
        hyp.mosaic = 0.0  # set mosaic ratio=0.0
        hyp.copy_paste = 0.0  # keep the same behavior as previous v8 close-mosaic
        hyp.mixup = 0.0  # keep the same behavior as previous v8 close-mosaic
        self.transforms = self.build_transforms(hyp)

    def update_labels_info(self, label):
        """custom your label format here"""
        # NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
        # we can make it also support classification and semantic segmentation by add or remove some dict keys there.
        bboxes = label.pop("bboxes")
        segments = label.pop("segments")
        keypoints = label.pop("keypoints", None)
        bbox_format = label.pop("bbox_format")
        normalized = label.pop("normalized")
        label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
        return label

    @staticmethod
    def collate_fn(batch):
        """Merge a list of per-sample dicts into one batch dict.

        ``img`` values are stacked; ``masks``, ``keypoints``, ``bboxes`` and ``cls`` are
        concatenated along dim 0; other values stay tuples. ``batch_idx`` is offset by the
        sample's position in the batch and concatenated.
        """
        new_batch = {}
        keys = batch[0].keys()
        values = list(zip(*[list(b.values()) for b in batch]))
        for i, k in enumerate(keys):
            value = values[i]
            if k == "img":
                value = torch.stack(value, 0)
            if k in ["masks", "keypoints", "bboxes", "cls"]:
                value = torch.cat(value, 0)
            new_batch[k] = value
        new_batch["batch_idx"] = list(new_batch["batch_idx"])
        for i in range(len(new_batch["batch_idx"])):
            new_batch["batch_idx"][i] += i  # add target image index for build_targets()
        new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
        return new_batch
class DFL(nn.Module):
    """Integral module of Distribution Focal Loss (DFL).

    Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
    A frozen 1x1 convolution whose weights are 0..c1-1 turns a softmax over ``c1`` bins
    into its expected value.
    """

    def __init__(self, c1=16):
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        bin_values = torch.arange(c1, dtype=torch.float).view(1, c1, 1, 1)
        self.conv.weight.data[:] = nn.Parameter(bin_values)
        self.c1 = c1

    def forward(self, x):
        """Map (batch, 4 * c1, anchors) logits to (batch, 4, anchors) expected bin values."""
        batch, _, anchors = x.shape
        probs = x.view(batch, 4, self.c1, anchors).transpose(2, 1).softmax(1)
        return self.conv(probs).view(batch, 4, anchors)
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Transform distance(ltrb) to box(xywh or xyxy).

    Args:
        distance (torch.Tensor): Distances split along ``dim`` into two chunks of size 2
            (left-top, right-bottom).
        anchor_points (torch.Tensor): Anchor coordinates, broadcastable against each chunk.
        xywh (bool): Return centre/size boxes when True, corner boxes otherwise.
        dim (int): Dimension holding the four distances.

    Returns:
        torch.Tensor: Boxes concatenated along ``dim``.
    """
    left_top, right_bottom = torch.split(distance, 2, dim)
    top_left = anchor_points - left_top
    bottom_right = anchor_points + right_bottom
    if not xywh:
        return torch.cat((top_left, bottom_right), dim)  # xyxy bbox
    centre = (top_left + bottom_right) / 2
    size = bottom_right - top_left
    return torch.cat((centre, size), dim)  # xywh bbox
def post_process(x):
    """Decode raw head outputs into boxes and class scores.

    Each tensor in ``x`` is flattened to (batch, 144, -1), all are concatenated along the last
    dimension and split into 64 box channels and 80 class channels. Boxes are decoded with
    DFL(16) and ``dist2bbox`` against the anchors and scaled by the strides; both arrays are
    read from ``./anchors.npy`` and ``./strides.npy`` in the current working directory on
    every call.

    Args:
        x (list[torch.Tensor]): Head outputs, each viewable as (batch, 144, -1).

    Returns:
        tuple: ``(y, x)`` where ``y`` concatenates xywh boxes and sigmoid class scores along
        dim 1, and ``x`` is the unchanged input.
    """
    dfl = DFL(16)
    # NOTE(review): allow_pickle=True permits unpickling; confirm these files are trusted.
    anchors = torch.tensor(np.load("./anchors.npy", allow_pickle=True))
    strides = torch.tensor(np.load("./strides.npy", allow_pickle=True))

    batch = x[0].shape[0]
    flattened = [xi.view(batch, 144, -1) for xi in x]
    box, cls = torch.cat(flattened, 2).split((16 * 4, 80), 1)

    decoded = dist2bbox(dfl(box), anchors.unsqueeze(0), xywh=True, dim=1)
    dbox = decoded * strides
    y = torch.cat((dbox, cls.sigmoid()), 1)
    return y, x
def smooth(y, f=0.05):
    """Box filter of fraction f.

    The kernel length is forced to be odd and ``y`` is padded on both sides with its edge
    values, so the output has the same length as ``y``.

    Args:
        y (np.ndarray): 1-D signal.
        f (float): Filter width as a fraction of ``len(y)``.

    Returns:
        np.ndarray: Smoothed signal.
    """
    kernel_len = round(len(y) * f * 2) // 2 + 1  # number of filter elements (must be odd)
    edge = np.ones(kernel_len // 2)  # ones padding
    padded = np.concatenate((edge * y[0], y, edge * y[-1]), 0)
    kernel = np.ones(kernel_len) / kernel_len
    return np.convolve(padded, kernel, mode='valid')
def compute_ap(recall, precision, method="interp"):
    """ Compute the average precision, given the recall and precision curves
    # Arguments
        recall:    The recall curve (list)
        precision: The precision curve (list)
        method:    'interp' (default) integrates a 101-point interpolation of the curve
                   (COCO style); 'continuous' sums the area at each recall change.
    # Returns
        Average precision, precision envelope curve, recall curve (the latter two include
        the sentinel values added at both ends)
    # Raises
        ValueError: if ``method`` is neither 'interp' nor 'continuous'
    """
    # Append sentinel values to beginning and end
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([1.0], precision, [0.0]))
    # Compute the precision envelope
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
    # Integrate area under curve
    if method == "interp":
        x = np.linspace(0, 1, 101)  # 101-point interp (COCO)
        # np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever this NumPy provides
        trapezoid = getattr(np, "trapezoid", None) or np.trapz
        ap = trapezoid(np.interp(x, mrec, mpre), x)  # integrate
    elif method == "continuous":
        i = np.where(mrec[1:] != mrec[:-1])[0]  # points where x-axis (recall) changes
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])  # area under curve
    else:
        raise ValueError(f"method must be 'interp' or 'continuous', got {method!r}")
    return ap, mpre, mrec
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), names=(), eps=1e-16, prefix=""):
    """ Compute per-class precision, recall, F1 and average precision.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
        tp:  True positives (nparray, nx1 or nx10).
        conf:  Objectness value from 0-1 (nparray).
        pred_cls:  Predicted object classes (nparray).
        target_cls:  True object classes (nparray).
        plot:  Only controls whether the mAP@0.5 precision curve is collected; nothing is
               drawn or saved by this function.
        save_dir:  Unused here; kept for interface compatibility.
        names:  Unused here; kept for interface compatibility.
        eps:  Small constant guarding the divisions.
        prefix:  Unused here; kept for interface compatibility.
    # Returns
        (tp, fp, p, r, f1, ap, unique_classes): per-class true/false positive counts,
        precision, recall and F1 at the confidence that maximises the smoothed mean F1,
        the AP matrix of shape (nc, tp.shape[1]), and the class ids as int.
    """
    # Sort by objectness
    order = np.argsort(-conf)
    tp, conf, pred_cls = tp[order], conf[order], pred_cls[order]
    # Find unique classes
    unique_classes, nt = np.unique(target_cls, return_counts=True)
    nc = unique_classes.shape[0]  # number of classes
    # Create Precision-Recall curve and compute AP for each class
    px, py = np.linspace(0, 1, 1000), []  # for plotting
    ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
    for ci, c in enumerate(unique_classes):
        i = pred_cls == c
        n_l = nt[ci]  # number of labels
        n_p = i.sum()  # number of predictions
        if n_p == 0 or n_l == 0:
            continue
        # Accumulate FPs and TPs
        fpc = (1 - tp[i]).cumsum(0)
        tpc = tp[i].cumsum(0)
        # Recall
        recall = tpc / (n_l + eps)  # recall curve
        r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0)  # negative x, xp because xp decreases
        # Precision
        precision = tpc / (tpc + fpc)  # precision curve
        p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1)  # p at pr_score
        # AP from recall-precision curve
        for j in range(tp.shape[1]):
            ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
            if plot and j == 0:
                py.append(np.interp(px, mrec, mpre))  # precision at mAP@0.5
    # Compute F1 (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + eps)
    # The former filtering of `names` via names.items() was removed: its result was never used
    # and it raised AttributeError for the default names=().
    i = smooth(f1.mean(0), 0.1).argmax()  # max F1 index
    p, r, f1 = p[:, i], r[:, i], f1[:, i]
    tp = (r * nt).round()  # true positives
    fp = (tp / (p + eps) - tp).round()  # false positives
    return tp, fp, p, r, f1, ap, unique_classes.astype(int)
class Metric:
    """Holds per-class precision, recall, F1 and AP values and derives summary statistics."""

    def __init__(self) -> None:
        self.p = []  # (nc, )
        self.r = []  # (nc, )
        self.f1 = []  # (nc, )
        self.all_ap = []  # (nc, 10)
        self.ap_class_index = []  # (nc, )
        self.nc = 0

    @property
    def ap50(self):
        """AP@0.5 of all classes.
        Return:
            (nc, ) or [].
        """
        if not len(self.all_ap):
            return []
        return self.all_ap[:, 0]

    @property
    def ap(self):
        """AP@0.5:0.95
        Return:
            (nc, ) or [].
        """
        if not len(self.all_ap):
            return []
        return self.all_ap.mean(1)

    @property
    def mp(self):
        """mean precision of all classes.
        Return:
            float.
        """
        if not len(self.p):
            return 0.0
        return self.p.mean()

    @property
    def mr(self):
        """mean recall of all classes.
        Return:
            float.
        """
        if not len(self.r):
            return 0.0
        return self.r.mean()

    @property
    def map50(self):
        """Mean AP@0.5 of all classes.
        Return:
            float.
        """
        if not len(self.all_ap):
            return 0.0
        return self.all_ap[:, 0].mean()

    @property
    def map75(self):
        """Mean AP@0.75 of all classes.
        Return:
            float.
        """
        if not len(self.all_ap):
            return 0.0
        return self.all_ap[:, 5].mean()

    @property
    def map(self):
        """Mean AP@0.5:0.95 of all classes.
        Return:
            float.
        """
        if not len(self.all_ap):
            return 0.0
        return self.all_ap.mean()

    def mean_results(self):
        """Mean of results, return mp, mr, map50, map"""
        return [self.mp, self.mr, self.map50, self.map]

    def class_result(self, i):
        """class-aware result, return p[i], r[i], ap50[i], ap[i]"""
        return self.p[i], self.r[i], self.ap50[i], self.ap[i]

    @property
    def maps(self):
        """mAP of each class; classes without an entry in ap_class_index get the overall map."""
        per_class = np.zeros(self.nc) + self.map
        for pos, class_id in enumerate(self.ap_class_index):
            per_class[class_id] = self.ap[pos]
        return per_class

    def fitness(self):
        """Model fitness as a weighted combination of [P, R, mAP@0.5, mAP@0.5:0.95]."""
        weights = [0.0, 0.0, 0.1, 0.9]
        weighted = np.array(self.mean_results()) * weights
        return weighted.sum()

    def update(self, results):
        """
        Args:
            results: tuple(p, r, f1, all_ap, ap_class_index), unpacked in that order
        """
        p, r, f1, all_ap, ap_class_index = results
        self.p, self.r, self.f1, self.all_ap, self.ap_class_index = p, r, f1, all_ap, ap_class_index
class DetMetrics:
    """Detection metrics facade that delegates the numbers to a single ``Metric`` stored as ``box``."""

    def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
        self.save_dir = save_dir
        self.plot = plot
        self.names = names
        self.box = Metric()

    def process(self, tp, conf, pred_cls, target_cls):
        """Run ``ap_per_class`` and store everything after its first two return values in ``self.box``."""
        stats = ap_per_class(tp, conf, pred_cls, target_cls, plot=self.plot, save_dir=self.save_dir,
                             names=self.names)
        self.box.nc = len(self.names)
        self.box.update(stats[2:])

    @property
    def keys(self):
        """Names of the four summary metrics, in the order of ``mean_results``."""
        return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]

    def mean_results(self):
        """Return ``self.box.mean_results()``."""
        return self.box.mean_results()

    def class_result(self, i):
        """Return ``self.box.class_result(i)``."""
        return self.box.class_result(i)

    @property
    def maps(self):
        return self.box.maps

    @property
    def fitness(self):
        return self.box.fitness()

    @property
    def ap_class_index(self):
        return self.box.ap_class_index

    @property
    def results_dict(self):
        """Map each entry of ``keys`` plus ``"fitness"`` to its current value."""
        labels = self.keys + ["fitness"]
        values = self.mean_results() + [self.fitness]
        return dict(zip(labels, values))
def increment_path(path, exist_ok=False, sep='', mkdir=False):
    """
    Return a path that does not collide with an existing one, i.e. runs/exp --> runs/exp{sep}2,
    runs/exp{sep}3, ... etc.

    When ``path`` exists and ``exist_ok`` is False, numbers from 2 upwards are tried until an
    unused path is found. For an existing file the number goes before the extension, for a
    directory it is appended to the end. With ``mkdir=True`` the resulting path is created as a
    directory (including parents).

    Args:
        path (str or pathlib.Path): Path to increment.
        exist_ok (bool, optional): If True, the path will not be incremented and will be returned as-is. Defaults to False.
        sep (str, optional): Separator to use between the path and the incrementation number. Defaults to an empty string.
        mkdir (bool, optional): If True, the path will be created as a directory if it does not exist. Defaults to False.

    Returns:
        pathlib.Path: Incremented path.
    """
    path = Path(path)  # os-agnostic
    needs_increment = path.exists() and not exist_ok
    if needs_increment:
        if path.is_file():
            base, suffix = path.with_suffix(''), path.suffix
        else:
            base, suffix = path, ''
        for n in range(2, 9999):
            candidate = f'{base}{sep}{n}{suffix}'  # increment path
            if not os.path.exists(candidate):
                break
        path = Path(candidate)
    if mkdir:
        path.mkdir(parents=True, exist_ok=True)  # make directory
    return path
def cfg2dict(cfg):
    """
    Convert a configuration object to a dictionary.

    A ``str`` or ``Path`` is loaded through ``yaml_load``; a ``SimpleNamespace`` is converted
    with ``vars`` (so the returned dict is the namespace's own attribute dict); anything else
    is returned unchanged.

    Inputs:
        cfg (str) or (Path) or (SimpleNamespace): Configuration object to be converted to a dictionary.

    Returns:
        cfg (dict): Configuration object in dictionary format.
    """
    if isinstance(cfg, SimpleNamespace):
        return vars(cfg)  # convert to dict
    if isinstance(cfg, (str, Path)):
        return yaml_load(cfg)  # load dict
    return cfg
def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = None, overrides: Dict = None):
    """
    Load and merge configuration data from a file or dictionary.

    The input objects are never modified: the merged configuration is built in a fresh dict.

    Args:
        cfg (str) or (Path) or (Dict) or (SimpleNamespace): Configuration data. None (the
            default) is treated as an empty configuration.
        overrides (str) or (Dict), optional: Overrides in the form of a file name or a dictionary. Default is None.

    Returns:
        (SimpleNamespace): Training arguments namespace.

    Raises:
        TypeError: If a known key holds a value of the wrong type.
        ValueError: If a fraction key holds a value outside [0.0, 1.0].
    """
    loaded = cfg2dict(cfg)
    # Copy so that neither the caller's dict nor a SimpleNamespace's attribute dict is mutated
    # below; None becomes an empty config rather than a TypeError.
    cfg = {} if loaded is None else dict(loaded)

    # Merge overrides
    if overrides:
        cfg.update(cfg2dict(overrides))  # prefer overrides

    # Special handling for numeric project/names
    for k in 'project', 'name':
        if k in cfg and isinstance(cfg[k], (int, float)):
            cfg[k] = str(cfg[k])

    # Type and Value checks
    for k, v in cfg.items():
        if v is not None:  # None values may be from optional args
            if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
                raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
                                f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
            elif k in CFG_FRACTION_KEYS:
                if not isinstance(v, (int, float)):
                    raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
                                    f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
                if not (0.0 <= v <= 1.0):
                    raise ValueError(f"'{k}={v}' is an invalid value. "
                                     f"Valid '{k}' values are between 0.0 and 1.0.")
            elif k in CFG_INT_KEYS and not isinstance(v, int):
                raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
                                f"'{k}' must be an int (i.e. '{k}=0')")
            elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
                raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
                                f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')")

    # Return instance
    return IterableSimpleNamespace(**cfg)
def clip_boxes(boxes, shape):
    """
    Clip bounding boxes in place so they lie within an image of the given shape.

    Columns 0 and 2 are limited to [0, shape[1]] and columns 1 and 3 to [0, shape[0]].
    Nothing is returned.

    Args:
      boxes (torch.Tensor | np.ndarray): the bounding boxes to clip
      shape (tuple): the shape of the image, (height, width)
    """
    height, width = shape[0], shape[1]
    if not isinstance(boxes, torch.Tensor):  # np.array (faster grouped)
        boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, width)  # x1, x2
        boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, height)  # y1, y2
        return
    # torch: faster individually
    for column, upper in ((0, width), (1, height), (2, width), (3, height)):
        boxes[..., column].clamp_(0, upper)
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
    """
    Rescale xyxy boxes from the image they were predicted on (img1_shape) to another image
    (img0_shape), undoing padding and resize gain. ``boxes`` is modified in place and clipped
    to ``img0_shape``.

    Args:
      img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
      boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
      img0_shape (tuple): the shape of the target image, in the format of (height, width).
      ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
                         calculated based on the size difference between the two images.

    Returns:
      boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
    """
    if ratio_pad is not None:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]
    else:  # calculate from img0_shape
        gain_h = img1_shape[0] / img0_shape[0]
        gain_w = img1_shape[1] / img0_shape[1]
        gain = min(gain_h, gain_w)  # gain = old / new
        pad_w = (img1_shape[1] - img0_shape[1] * gain) / 2
        pad_h = (img1_shape[0] - img0_shape[0] * gain) / 2
        pad = pad_w, pad_h  # wh padding

    boxes[..., [0, 2]] -= pad[0]  # x padding
    boxes[..., [1, 3]] -= pad[1]  # y padding
    boxes[..., :4] /= gain
    clip_boxes(boxes, img0_shape)
    return boxes
def exif_size(img):
    """Return the PIL image size, with width and height swapped for EXIF rotations 6 and 8.

    Any exception while reading the EXIF data is suppressed, in which case the plain
    ``img.size`` is returned.

    Args:
        img: Image object exposing ``size`` and ``_getexif()``.

    Returns:
        tuple: (width, height), EXIF-corrected where possible.
    """
    size = img.size  # (width, height)
    with contextlib.suppress(Exception):
        # `orientation` is a name defined outside this function (not visible in this section).
        exif = dict(img._getexif().items())
        if exif[orientation] in [6, 8]:  # rotation 270 or 90
            size = (size[1], size[0])
    return size
def verify_image_label(args):
    """
    Verify one image-label pair and load its labels.

    The image is opened and verified with PIL and must be larger than 9 pixels on both sides and of a format
    listed in IMG_FORMATS. If the label file exists, its rows are parsed and validated: coordinates must be
    <= 1, no value may be negative, and class ids must lie in 0..num_cls - 1. Duplicate rows are removed.

    Args:
      args (tuple): (im_file, lb_file, prefix, keypoint, num_cls)
        im_file: path of the image file.
        lb_file: path of the label file (may not exist).
        prefix: string prepended to warning messages.
        keypoint: truthy to parse 56-column keypoint labels (5 box columns + 17 x (x, y, occlusion)).
        num_cls: number of classes in the dataset.

    Returns:
      On success, a tuple (im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg):
        lb is an (n, 5) array of (cls, box) rows, shape is (height, width), segments is a list of (k, 2)
        arrays (empty unless segment labels were found), keypoints is an (n, 17, 2) array or None, and
        nm / nf / ne / nc are 0/1 flags for label missing / found / empty / corrupt. msg is "" or a warning.
      On any exception, the list [None, None, None, None, None, nm, nf, ne, nc, msg] with nc set to 1 and msg
      describing the error.
    """
    im_file, lb_file, prefix, keypoint, num_cls = args
    # number (missing, found, empty, corrupt), message, segments, keypoints
    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
    try:
        # verify images
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size (width, height)
        shape = (shape[1], shape[0])  # hw
        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
        assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"

        # verify labels
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file) as f:
                lb = [x.split() for x in f.read().strip().splitlines() if x]
            if any(len(x) > 6 for x in lb) and (not keypoint):  # is segment
                classes = np.array([x[0] for x in lb], dtype=np.float32)
                segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
                lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
            lb = np.array(lb, dtype=np.float32)
            nl = len(lb)
            if nl:
                if keypoint:
                    assert lb.shape[1] == 56, "labels require 56 columns each"
                    assert (lb[:, 5::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
                    assert (lb[:, 6::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
                    # remove occlusion param from GT: drop every third keypoint column in one pass
                    occlusion_cols = np.arange(2, lb.shape[1] - 5, 3)
                    kpts = np.zeros((lb.shape[0], 39))
                    kpts[:, :5] = lb[:, :5]
                    kpts[:, 5:] = np.delete(lb[:, 5:], occlusion_cols, axis=1)
                    lb = kpts
                    assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion parameter"
                else:
                    assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
                    assert (lb[:, 1:] <= 1).all(), \
                        f"non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}"
                # All labels
                max_cls = int(lb[:, 0].max())  # largest class id present
                # class ids are zero-based, so num_cls itself is already out of range
                assert max_cls < num_cls, \
                    f'Label class {max_cls} exceeds dataset class count {num_cls}. ' \
                    f'Possible class labels are 0-{num_cls - 1}'
                assert (lb >= 0).all(), f"negative label values {lb[lb < 0]}"
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate row check
                    lb = lb[i]  # remove duplicates
                    if segments:
                        segments = [segments[x] for x in i]  # keep segments aligned with the kept rows
                    msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
            else:
                ne = 1  # label empty
                lb = np.zeros((0, 39 if keypoint else 5), dtype=np.float32)
        else:
            nm = 1  # label missing
            lb = np.zeros((0, 39 if keypoint else 5), dtype=np.float32)
        if keypoint:
            keypoints = lb[:, 5:].reshape(-1, 17, 2)
        lb = lb[:, :5]
        return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
    except Exception as e:
        # any failure (unreadable image, malformed or invalid labels) marks the pair as corrupt
        nc = 1
        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
        return [None, None, None, None, None, nm, nf, ne, nc, msg]