# Ultralytics YOLO 🚀, GPL-3.0 license

import contextlib
import hashlib
import os
import subprocess
import time
from pathlib import Path
from tarfile import is_tarfile
from zipfile import is_zipfile

import cv2
import numpy as np
import torch
from PIL import ExifTags, Image, ImageOps

from ultralytics.yolo.utils import LOGGER, ROOT, colorstr, yaml_load
from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii
from ultralytics.yolo.utils.downloads import download
from ultralytics.yolo.utils.files import unzip_file

from ..utils.ops import segments2boxes

HELP_URL = "See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data"
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"  # include image suffixes
VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv"  # include video suffixes
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))  # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"  # global pin_memory for dataloaders
IMAGENET_MEAN = 0.485, 0.456, 0.406  # RGB mean
IMAGENET_STD = 0.229, 0.224, 0.225  # RGB standard deviation

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == "Orientation":
        break


def img2label_paths(img_paths):
    # Define label paths as a function of image paths
    sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}"  # /images/, /labels/ substrings
    return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
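
# Illustrative usage sketch (assumed paths, not part of the original module): on a POSIX system the
# 'images' path component is swapped for 'labels' and the image suffix is replaced with '.txt':
# >>> img2label_paths(["/data/coco/images/train2017/000000000009.jpg"])
# ['/data/coco/labels/train2017/000000000009.txt']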


def get_hash(paths):
    # Returns a single hash value of a list of paths (files or dirs)
    size = sum(os.path.getsize(p) for p in paths if os.path.exists(p))  # sizes
    h = hashlib.md5(str(size).encode())  # hash sizes
    h.update("".join(paths).encode())  # hash paths
    return h.hexdigest()  # return hash
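
# Illustrative sketch (assumed paths, not part of the original module): the returned 32-character
# md5 hex digest changes when any listed path's size changes, e.g. to detect stale label caches:
# >>> get_hash(["datasets/coco128/labels/train2017", "datasets/coco128/images/train2017"])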


def exif_size(img):
    # Returns exif-corrected PIL size
    s = img.size  # (width, height)
    with contextlib.suppress(Exception):
        rotation = dict(img._getexif().items())[orientation]
        if rotation in [6, 8]:  # rotation 270 or 90
            s = (s[1], s[0])
    return s
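
# Illustrative sketch (hypothetical file, not part of the original module): for a 1280x960 JPEG whose
# EXIF orientation tag is 6 or 8 (rotated 90/270 degrees), width and height are swapped:
# >>> exif_size(Image.open("rotated.jpg"))
# (960, 1280)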


def verify_image_label(args):
    # Verify one image-label pair
    im_file, lb_file, prefix, keypoint = args
    # number (missing, found, empty, corrupt), message, segments, keypoints
    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
    try:
        # verify images
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        shape = (shape[1], shape[0])  # hw
        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
        assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
        if im.format.lower() in ("jpg", "jpeg"):
            with open(im_file, "rb") as f:
                f.seek(-2, 2)
                if f.read() != b"\xff\xd9":  # corrupt JPEG
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
                    msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"

        # verify labels
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file) as f:
                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
                if any(len(x) > 6 for x in lb) and (not keypoint):  # is segment
                    classes = np.array([x[0] for x in lb], dtype=np.float32)
                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
                    lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                lb = np.array(lb, dtype=np.float32)
            nl = len(lb)
            if nl:
                if keypoint:
                    assert lb.shape[1] == 56, "labels require 56 columns each"
                    assert (lb[:, 5::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
                    assert (lb[:, 6::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
                    kpts = np.zeros((lb.shape[0], 39))
                    for i in range(len(lb)):
                        kpt = np.delete(lb[i, 5:], np.arange(2, lb.shape[1] - 5, 3))  # remove the occlusion parameter from the GT
                        kpts[i] = np.hstack((lb[i, :5], kpt))
                    lb = kpts
                    assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion parameter"
                else:
                    assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
                    assert (lb >= 0).all(), f"negative label values {lb[lb < 0]}"
                    assert (lb[:, 1:] <= 1).all(), f"non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}"
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate row check
                    lb = lb[i]  # remove duplicates
                    if segments:
                        segments = [segments[x] for x in i]
                    msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
            else:
                ne = 1  # label empty
                lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
        else:
            nm = 1  # label missing
            lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
        if keypoint:
            keypoints = lb[:, 5:].reshape(-1, 17, 2)
            lb = lb[:, :5]
        return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
    except Exception as e:
        nc = 1
        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
        return [None, None, None, None, None, nm, nf, ne, nc, msg]
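
# Illustrative sketch (assumed paths, not part of the original module): the function is designed to be
# mapped over (image, label) pairs, e.g. from a multiprocessing pool, returning per-file statistics:
# >>> im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg = verify_image_label(
# ...     ("datasets/coco128/images/train2017/000000000009.jpg",
# ...      "datasets/coco128/labels/train2017/000000000009.txt", "", False))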


def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
    """
    Args:
        imgsz (tuple): The image size.
        polygons (np.ndarray): [N, M], N is the number of polygons, M is the number of coordinates (must be divisible by 2).
        color (int): fill color for the polygon pixels.
        downsample_ratio (int): ratio by which the output mask is downsampled.
    """
    mask = np.zeros(imgsz, dtype=np.uint8)
    polygons = np.asarray(polygons)
    polygons = polygons.astype(np.int32)
    shape = polygons.shape
    polygons = polygons.reshape(shape[0], -1, 2)
    cv2.fillPoly(mask, polygons, color=color)
    nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio)
    # NOTE: fillPoly at full resolution first, then resize, so the loss calculation
    # stays consistent with mask-ratio=1.
    mask = cv2.resize(mask, (nw, nh))
    return mask
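
# Illustrative sketch (assumed values, not part of the original module): rasterize a single triangle
# given as a flat (x1, y1, x2, y2, ...) array on a 160x160 canvas, downsampled by 4 to a 40x40 mask:
# >>> poly = np.array([[10, 10, 150, 10, 80, 140]], dtype=np.float32)
# >>> polygon2mask((160, 160), poly, color=1, downsample_ratio=4).shape
# (40, 40)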


def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
    """
    Args:
        imgsz (tuple): The image size.
        polygons (list[np.ndarray]): each polygon is [N, M], N is the number of polygons, M is the number of coordinates (M % 2 == 0).
        color (int): fill color for the polygon pixels.
        downsample_ratio (int): ratio by which the output masks are downsampled.
    """
    masks = []
    for si in range(len(polygons)):
        mask = polygon2mask(imgsz, [polygons[si].reshape(-1)], color, downsample_ratio)
        masks.append(mask)
    return np.array(masks)
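
# Illustrative sketch (assumed values, not part of the original module): two polygons become two
# stacked binary masks of shape (num_polygons, H // ratio, W // ratio):
# >>> segs = [np.array([10, 10, 60, 10, 60, 60], dtype=np.float32),
# ...         np.array([80, 80, 150, 80, 150, 150], dtype=np.float32)]
# >>> polygons2masks((160, 160), segs, color=1, downsample_ratio=1).shape
# (2, 160, 160)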


def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
    """Return an overlap mask of shape (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio) and the sorted instance index."""
    masks = np.zeros((imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
                     dtype=np.int32 if len(segments) > 255 else np.uint8)
    areas = []
    ms = []
    for si in range(len(segments)):
        mask = polygon2mask(
            imgsz,
            [segments[si].reshape(-1)],
            downsample_ratio=downsample_ratio,
            color=1,
        )
        ms.append(mask)
        areas.append(mask.sum())
    areas = np.asarray(areas)
    index = np.argsort(-areas)
    ms = np.array(ms)[index]
    for i in range(len(segments)):
        mask = ms[i] * (i + 1)
        masks = masks + mask
        masks = np.clip(masks, a_min=0, a_max=i + 1)
    return masks, index
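
# Illustrative sketch (assumed values, not part of the original module): overlapping instances are
# encoded in one index mask, with larger polygons drawn first so smaller ones stay on top:
# >>> segs = [np.array([0, 0, 100, 0, 100, 100, 0, 100], dtype=np.float32),
# ...         np.array([40, 40, 60, 40, 60, 60, 40, 60], dtype=np.float32)]
# >>> overlap, index = polygons2masks_overlap((160, 160), segs)
# >>> overlap.shape, sorted(np.unique(overlap).tolist())
# ((160, 160), [0, 1, 2])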


def check_dataset_yaml(data, autodownload=True):
    # Download, check and/or unzip dataset if not found locally
    data = check_file(data)
    DATASETS_DIR = (Path.cwd() / "../datasets").resolve()  # TODO: handle global dataset dir
    # Download (optional)
    extract_dir = ''
    if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)):
        download(data, dir=f'{DATASETS_DIR}/{Path(data).stem}', unzip=True, delete=False, curl=False, threads=1)
        data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
        extract_dir, autodownload = data.parent, False
    # Read yaml (optional)
    if isinstance(data, (str, Path)):
        data = yaml_load(data, append_filename=True)  # dictionary

    # Checks
    for k in 'train', 'val', 'names':
        assert k in data, f"data.yaml '{k}:' field missing ❌"
    if isinstance(data['names'], (list, tuple)):  # old array format
        data['names'] = dict(enumerate(data['names']))  # convert to dict
    data['nc'] = len(data['names'])

    # Resolve paths
    path = Path(extract_dir or data.get('path') or '')  # optional 'path' default to '.'
    if not path.is_absolute():
        path = (Path.cwd() / path).resolve()
    data['path'] = path  # download scripts
    for k in 'train', 'val', 'test':
        if data.get(k):  # prepend path
            if isinstance(data[k], str):
                x = (path / data[k]).resolve()
                if not x.exists() and data[k].startswith('../'):
                    x = (path / data[k][3:]).resolve()
                data[k] = str(x)
            else:
                data[k] = [str((path / x).resolve()) for x in data[k]]

    # Parse yaml
    train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            LOGGER.info('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()])
            if not s or not autodownload:
                raise FileNotFoundError('Dataset not found ❌')
            t = time.time()
            if s.startswith('http') and s.endswith('.zip'):  # URL
                f = Path(s).name  # filename
                LOGGER.info(f'Downloading {s} to {f}...')
                torch.hub.download_url_to_file(s, f)
                Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True)  # create root
                unzip_file(f, path=DATASETS_DIR)  # unzip
                Path(f).unlink()  # remove zip
                r = None  # success
            elif s.startswith('bash '):  # bash script
                LOGGER.info(f'Running {s} ...')
                r = os.system(s)
            else:  # python script
                r = exec(s, {'yaml': data})  # return None
            dt = f'({round(time.time() - t, 1)}s)'
            s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
            LOGGER.info(f"Dataset download {s}")
    check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True)  # download fonts
    return data  # dictionary
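
# Illustrative sketch (assumed YAML name, not part of the original module): resolve (and, if missing
# and autodownload is enabled, download) a detection dataset described by a YAML file:
# >>> data = check_dataset_yaml("coco128.yaml")
# >>> data["train"], data["nc"]  # absolute train path and number of classes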


def check_dataset(dataset: str):
    """
    Check a classification dataset such as Imagenet.

    This function takes a `dataset` name as input and returns a dictionary containing information about the dataset.
    If the dataset is not found locally, it attempts to download it from the internet and save it to the local file system.

    Args:
        dataset (str): Name of the dataset.

    Returns:
        data (dict): A dictionary containing the following keys and values:
            'train': Path object for the directory containing the training set of the dataset
            'val': Path object for the directory containing the validation set of the dataset
            'nc': Number of classes in the dataset
            'names': Dictionary mapping class indices to class names in the dataset
    """
    data_dir = (Path.cwd() / "datasets" / dataset).resolve()
    if not data_dir.is_dir():
        LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
        t = time.time()
        if dataset == 'imagenet':
            subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
        else:
            url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
            download(url, dir=data_dir.parent)
        s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
        LOGGER.info(s)
    train_set = data_dir / "train"
    test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val'  # data/test or data/val
    nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()])  # number of classes
    names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()]  # class names list
    names = dict(enumerate(sorted(names)))
    return {"train": train_set, "val": test_set, "nc": nc, "names": names}