Spaces:
Runtime error
Runtime error
| # Ultralytics YOLO 🚀, AGPL-3.0 license | |
| import contextlib | |
| import glob | |
| import inspect | |
| import math | |
| import os | |
| import platform | |
| import re | |
| import shutil | |
| import subprocess | |
| import sys | |
| import time | |
| from importlib import metadata | |
| from pathlib import Path | |
| from typing import Optional | |
| import cv2 | |
| import numpy as np | |
| import requests | |
| import torch | |
| from matplotlib import font_manager | |
| from ultralytics.utils import (ASSETS, AUTOINSTALL, LINUX, LOGGER, ONLINE, ROOT, USER_CONFIG_DIR, SimpleNamespace, | |
| ThreadingLocked, TryExcept, clean_url, colorstr, downloads, emojis, is_colab, is_docker, | |
| is_github_action_running, is_jupyter, is_kaggle, is_online, is_pip_package, url2file) | |
def parse_requirements(file_path=ROOT.parent / 'requirements.txt', package=''):
    """
    Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'.

    Args:
        file_path (Path): Path to the requirements.txt file.
        package (str, optional): Python package to use instead of requirements.txt file, i.e. package='ultralytics'.

    Returns:
        (List[SimpleNamespace]): Parsed requirements, each exposing `name` and `specifier` string attributes.
            `specifier` is '' when the requirement carries no version constraint.

    Example:
        ```python
        from ultralytics.utils.checks import parse_requirements

        parse_requirements(package='ultralytics')
        ```
    """
    if package:
        # Distribution.requires is None (not an empty list) when the package declares no dependencies
        declared = metadata.distribution(package).requires or []
        requires = [x for x in declared if 'extra == ' not in x]  # drop optional-extra requirements
    else:
        requires = Path(file_path).read_text().splitlines()

    requirements = []
    for line in requires:
        line = line.strip()
        if line and not line.startswith('#'):
            line = line.split('#')[0].strip()  # ignore inline comments
            match = re.match(r'([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?', line)
            if match:
                requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else ''))

    return requirements
def parse_version(version='0.0.0') -> tuple:
    """
    Turn a version string into a tuple of integers, discarding any non-numeric decoration. Stands in for the
    deprecated 'pkg_resources.parse_version(v)'.

    Args:
        version (str): Version string, i.e. '2.0.1+cpu'

    Returns:
        (tuple): Up to the first three integer fields found in the string, i.e. (2, 0, 1). Falls back to (0, 0, 0)
            with a logged warning when the string cannot be processed.
    """
    try:
        numeric_fields = re.findall(r'\d+', version)
        return tuple(int(field) for field in numeric_fields[:3])  # '2.0.1+cpu' -> (2, 0, 1)
    except Exception as e:
        LOGGER.warning(f'WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}')
        return 0, 0, 0
def is_ascii(s) -> bool:
    """
    Report whether the string form of a value contains ASCII characters only.

    Args:
        s (str): Value to inspect; non-string inputs (list, tuple, None, ...) are stringified first.

    Returns:
        bool: True when every character has a code point below 128, False otherwise.
    """
    text = str(s)
    # No character may sit outside the 7-bit ASCII range
    return not any(ord(ch) >= 128 for ch in text)
def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
    """
    Ensure every image dimension is a multiple of the stride, rounding up where necessary and never going below
    the floor value.

    Args:
        imgsz (int | List[int]): Image size.
        stride (int): Stride value; a tensor of strides is reduced to its maximum.
        min_dim (int): Minimum number of dimensions.
        max_dim (int): Maximum number of dimensions.
        floor (int): Minimum allowed value for image size.

    Returns:
        (int | List[int]): Updated image size. A single dimension is returned as an int when min_dim == 1 and
            duplicated into a two-element list when min_dim == 2.

    Raises:
        TypeError: If imgsz is neither an int nor a list/tuple.
        ValueError: If imgsz has more than max_dim entries and max_dim is not 1.
    """
    # Reduce a tensor of strides to its largest value
    if isinstance(stride, torch.Tensor):
        stride = stride.max()
    stride = int(stride)

    # Normalise the requested size to a list
    if isinstance(imgsz, int):
        dims = [imgsz]
    elif isinstance(imgsz, (list, tuple)):
        dims = list(imgsz)
    else:
        raise TypeError(f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
                        f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'")

    # Too many dimensions: only tolerated (by collapsing to the largest) when a single dimension is expected
    if len(dims) > max_dim:
        msg = "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list " \
              "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
        if max_dim != 1:
            raise ValueError(f'imgsz={dims} is not a valid image size. {msg}')
        LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(dims)}'. {msg}")
        dims = [max(dims)]

    # Round each dimension up to the next stride multiple, respecting the floor
    rounded = [max(math.ceil(d / stride) * stride, floor) for d in dims]
    if rounded != dims:
        LOGGER.warning(f'WARNING ⚠️ imgsz={dims} must be multiple of max stride {stride}, updating to {rounded}')

    # Shape the result according to the minimum dimension count
    if len(rounded) == 1:
        if min_dim == 2:
            return [rounded[0], rounded[0]]
        if min_dim == 1:
            return rounded[0]
    return rounded
def check_version(current: str = '0.0.0',
                  required: str = '0.0.0',
                  name: str = 'version',
                  hard: bool = False,
                  verbose: bool = False,
                  msg: str = '') -> bool:
    """
    Check current version against the required version or range.

    Args:
        current (str): Current version or package name to get version from.
        required (str): Required version or range (in pip-style format). Comma-separated constraints must all
            hold; whitespace around individual constraints is ignored.
        name (str, optional): Name to be used in warning message.
        hard (bool, optional): If True, raise a ModuleNotFoundError if the requirement is not met.
        verbose (bool, optional): If True, print warning message if requirement is not met.
        msg (str, optional): Extra message to display if verbose.

    Returns:
        (bool): True if requirement is met, False otherwise.

    Raises:
        ModuleNotFoundError: If `hard` is True and the package is missing or the requirement is not met.

    Example:
        ```python
        # Check if current version is exactly 22.04
        check_version(current='22.04', required='==22.04')

        # Check if current version is greater than or equal to 22.04
        check_version(current='22.10', required='22.04')  # assumes '>=' inequality if none passed

        # Check if current version is less than or equal to 22.04
        check_version(current='22.04', required='<=22.04')

        # Check if current version is between 20.04 (inclusive) and 22.04 (exclusive)
        check_version(current='21.10', required='>20.04,<22.04')
        ```
    """
    if not current:  # if current is '' or None
        LOGGER.warning(f'WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.')
        return True
    elif not current[0].isdigit():  # current is package name rather than version string, i.e. current='ultralytics'
        try:
            name = current  # assigned package name to 'name' arg
            current = metadata.version(current)  # get version string from package name
        except metadata.PackageNotFoundError:
            if hard:
                raise ModuleNotFoundError(emojis(f'WARNING ⚠️ {current} package is required but not installed'))
            else:
                return False

    if not required:  # if required is '' or None
        return True

    result = True
    c = parse_version(current)  # '1.2.3' -> (1, 2, 3)
    for r in required.strip(',').split(','):
        r = r.strip()  # '>20.04, <22.04' must not leak the space into the operator
        if not r:
            continue  # tolerate empty items, i.e. '>=1.0,,<2.0'
        match = re.match(r'([^0-9]*)([\d.]+)', r)  # split '>=22.04' -> ('>=', '22.04')
        if not match:
            LOGGER.warning(f'WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.')
            continue
        op, version = match.groups()
        op = op.strip()
        v = parse_version(version)  # '1.2.3' -> (1, 2, 3)

        # Zero-pad to a common length so that '1.2' and '1.2.0' compare as equal instead of (1, 2) < (1, 2, 0)
        width = max(len(c), len(v))
        cur = c + (0, ) * (width - len(c))
        req = v + (0, ) * (width - len(v))

        if op == '==' and cur != req:
            result = False
        elif op == '!=' and cur == req:
            result = False
        elif op in ('>=', '') and not (cur >= req):  # if no constraint passed assume '>=required'
            result = False
        elif op == '<=' and not (cur <= req):
            result = False
        elif op == '>' and not (cur > req):
            result = False
        elif op == '<' and not (cur < req):
            result = False

    if not result:
        # Report the full requirement rather than whichever constraint happened to be parsed last
        warning = f'WARNING ⚠️ {name}{required} is required, but {name}=={current} is currently installed {msg}'
        if hard:
            raise ModuleNotFoundError(emojis(warning))  # assert version requirements met
        if verbose:
            LOGGER.warning(warning)
    return result
def check_latest_pypi_version(package_name='ultralytics'):
    """
    Look up the newest released version of a package on PyPI without downloading or installing it.

    Parameters:
        package_name (str): The name of the package to find the latest version for.

    Returns:
        (str | None): The latest version of the package, or None when the lookup fails for any reason
            (request error, non-200 status, unexpected payload).
    """
    with contextlib.suppress(Exception):  # best-effort: any failure yields None
        requests.packages.urllib3.disable_warnings()  # Disable the InsecureRequestWarning
        pypi_url = f'https://pypi.org/pypi/{package_name}/json'
        reply = requests.get(pypi_url, timeout=3)
        if reply.status_code != 200:
            return None
        return reply.json()['info']['version']
    return None
def check_pip_update_available():
    """
    Tell whether PyPI offers a newer ultralytics release than the one installed.

    The check only runs when online and when ultralytics was installed as a pip package; any error raised while
    checking is suppressed and reported as "no update".

    Returns:
        (bool): True if an update is available, False otherwise.
    """
    if not (ONLINE and is_pip_package()):
        return False

    with contextlib.suppress(Exception):
        from ultralytics import __version__
        latest = check_latest_pypi_version()
        update_pending = check_version(__version__, f'<{latest}')  # installed version strictly older than latest
        if update_pending:
            LOGGER.info(f'New https://pypi.org/project/ultralytics/{latest} available 😃 '
                        f"Update with 'pip install -U ultralytics'")
            return True
    return False
def check_font(font='Arial.ttf'):
    """
    Locate a font locally, downloading it into the user's configuration directory when it is missing.

    Lookup order: USER_CONFIG_DIR, then system fonts reported by matplotlib, then a download from
    ultralytics.com.

    Args:
        font (str): Path or name of font.

    Returns:
        file (Path | str | None): Font file path. A Path for the config-dir/download cases, the matching system
            font string otherwise, and None when the download URL is not valid.
    """
    font_name = Path(font).name
    cached = USER_CONFIG_DIR / font_name

    # 1. Already present in USER_CONFIG_DIR
    if cached.exists():
        return cached

    # 2. Installed system font whose path contains the requested name
    system_hits = [candidate for candidate in font_manager.findSystemFonts() if font in candidate]
    if any(system_hits):
        return system_hits[0]

    # 3. Download to USER_CONFIG_DIR
    # NOTE(review): the return is placed inside the URL check, reconstructed from whitespace-stripped source
    url = f'https://ultralytics.com/assets/{font_name}'
    if downloads.is_url(url):
        downloads.safe_download(url=url, file=cached)
        return cached
    return None
def check_python(minimum: str = '3.8.0') -> bool:
    """
    Verify that the running Python interpreter is at least the required minimum version.

    Args:
        minimum (str): Required minimum version of python.

    Returns:
        (bool): Result of check_version(); because the check is 'hard', an unmet requirement raises instead of
            returning False.
    """
    running = platform.python_version()
    return check_version(running, minimum, name='Python ', hard=True)
def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=(), install=True, cmds=''):
    """
    Verify that installed dependencies satisfy YOLOv8 requirements, optionally attempting a pip auto-update.

    Args:
        requirements (Union[Path, str, List[str]]): Path to a requirements.txt file, a single package requirement as a
            string, or a list of package requirements as strings.
        exclude (Tuple[str]): Tuple of package names to exclude from checking.
        install (bool): If True, attempt to auto-update packages that don't meet requirements.
        cmds (str): Additional commands to pass to the pip install command when auto-updating.

    Returns:
        (bool): True when all requirements are met or the auto-update succeeded, False otherwise.

    Example:
        ```python
        from ultralytics.utils.checks import check_requirements

        # Check a requirements.txt file
        check_requirements('path/to/requirements.txt')

        # Check a single package
        check_requirements('ultralytics>=8.0.0')

        # Check multiple packages
        check_requirements(['numpy', 'ultralytics>=8.0.0'])
        ```
    """
    prefix = colorstr('red', 'bold', 'requirements:')
    check_python()  # check python version
    check_torchvision()  # check torch-torchvision compatibility

    # Normalise the input to a list of requirement strings
    if isinstance(requirements, Path):  # requirements.txt file
        file = requirements.resolve()
        assert file.exists(), f'{prefix} {file} not found, check failed.'
        requirements = [f'{x.name}{x.specifier}' for x in parse_requirements(file) if x.name not in exclude]
    elif isinstance(requirements, str):
        requirements = [requirements]

    # Collect every requirement that is missing or at the wrong version
    pkgs = []
    for r in requirements:
        tail = r.rsplit('/', 1)[-1].replace('.git', '')  # replace git+https://org/repo.git -> 'repo'
        match = re.match(r'([a-zA-Z0-9-_]+)([<>!=~]+.*)?', tail)
        name = match[1]
        required = match[2].strip() if match[2] else ''
        try:
            satisfied = check_version(metadata.version(name), required)
        except (AssertionError, metadata.PackageNotFoundError):
            satisfied = False
        if not satisfied:
            pkgs.append(r)

    s = ' '.join(f'"{x}"' for x in pkgs)  # console string
    if not s:
        return True  # nothing to install
    if not (install and AUTOINSTALL):  # check environment variable
        return False

    n = len(pkgs)  # number of packages updates
    plural = 's' * (n > 1)
    LOGGER.info(f'{prefix} Ultralytics requirement{plural} {pkgs} not found, attempting AutoUpdate...')
    try:
        t = time.time()
        assert is_online(), 'AutoUpdate skipped (offline)'
        LOGGER.info(subprocess.check_output(f'pip install --no-cache {s} {cmds}', shell=True).decode())
        dt = time.time() - t
        LOGGER.info(
            f'{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{plural}: {pkgs}\n'
            f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n")
    except Exception as e:
        LOGGER.warning(f'{prefix} ❌ {e}')
        return False

    return True
def check_torchvision():
    """
    Warn when the installed PyTorch and Torchvision versions are known to be incompatible.

    Only the major.minor parts of both versions are compared, against a small table based on
    https://github.com/pytorch/vision#installation whose keys are PyTorch versions and whose values are lists of
    compatible Torchvision versions. PyTorch versions that are not in the table are not checked.
    """
    import torchvision

    def major_minor(full_version):
        """Strip any '+local' suffix and keep 'major.minor', i.e. '2.0.1+cu118' -> '2.0'."""
        return '.'.join(full_version.split('+')[0].split('.')[:2])

    # Compatibility table
    compatibility_table = {'2.0': ['0.15'], '1.13': ['0.14'], '1.12': ['0.13']}

    v_torch = major_minor(torch.__version__)
    v_torchvision = major_minor(torchvision.__version__)

    compatible_versions = compatibility_table.get(v_torch)
    if compatible_versions is None:
        return  # unknown torch version, nothing to compare against

    if v_torchvision not in compatible_versions:
        print(f'WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n'
              f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
              "'pip install -U torch torchvision' to update both.\n"
              'For a full compatibility table see https://github.com/pytorch/vision#installation')
def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):
    """
    Assert that one or more file names carry an acceptable suffix.

    Args:
        file (str | Path | list | tuple): File name(s) to validate; falsy values skip the check.
        suffix (str | tuple): Acceptable suffix(es); falsy values skip the check.
        msg (str): Text prepended to the assertion message.

    Raises:
        AssertionError: If a file has a (lower-cased) suffix that is not among the acceptable ones. Files without
            any suffix are not rejected.
    """
    if not (file and suffix):
        return

    allowed = (suffix, ) if isinstance(suffix, str) else suffix
    candidates = file if isinstance(file, (list, tuple)) else [file]
    for f in candidates:
        ext = Path(f).suffix.lower().strip()  # file suffix
        if ext:
            assert ext in allowed, f'{msg}{f} acceptable suffix is {allowed}, not {ext}'
def check_yolov5u_filename(file: str, verbose: bool = True):
    """
    Map legacy YOLOv3/YOLOv5 file names onto their updated 'u' counterparts.

    Args:
        file (str): Model file name or path.
        verbose (bool): If True, log a tip when a '.pt' name was rewritten.

    Returns:
        (str): The possibly rewritten file name. '*u.yaml' names lose the 'u' (yolov5nu.yaml -> yolov5n.yaml),
            while '.pt' names that contain no 'u' at all gain one (yolov5n.pt -> yolov5nu.pt).
    """
    if 'yolov3' not in file and 'yolov5' not in file:
        return file

    if 'u.yaml' in file:
        return file.replace('u.yaml', '.yaml')  # i.e. yolov5nu.yaml -> yolov5n.yaml

    if '.pt' in file and 'u' not in file:
        original_file = file
        legacy_patterns = (
            r'(.*yolov5([nsmlx]))\.pt',  # i.e. yolov5n.pt -> yolov5nu.pt
            r'(.*yolov5([nsmlx])6)\.pt',  # i.e. yolov5n6.pt -> yolov5n6u.pt
            r'(.*yolov3(|-tiny|-spp))\.pt',  # i.e. yolov3-spp.pt -> yolov3-sppu.pt
        )
        for pattern in legacy_patterns:
            file = re.sub(pattern, '\\1u.pt', file)
        if file != original_file and verbose:
            LOGGER.info(
                f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
                f'trained with https://github.com/ultralytics/ultralytics and feature improved performance vs '
                f'standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n')
    return file
def check_model_file_from_stem(model='yolov8n'):
    """
    Expand a bare, known model stem into a '.pt' file name.

    Args:
        model (str | Path): Model name; returned untouched when it is falsy, already has a suffix, or its stem is
            not listed in downloads.GITHUB_ASSETS_STEMS.

    Returns:
        (Path | str): Path with '.pt' appended for a recognised stem, otherwise the input unchanged.
    """
    is_known_stem = model and not Path(model).suffix and Path(model).stem in downloads.GITHUB_ASSETS_STEMS
    if is_known_stem:
        return Path(model).with_suffix('.pt')  # add suffix, i.e. yolov8n -> yolov8n.pt
    return model
def check_file(file, suffix='', download=True, hard=True):
    """
    Resolve a file reference to a usable path, downloading or searching for it when needed.

    Args:
        file (str | Path): Local path, URL, gRPC address or bare file name.
        suffix (str | tuple): Optional acceptable suffix(es), validated with check_suffix().
        download (bool): If True, URLs with a supported scheme are downloaded when not already present locally.
        hard (bool): If True, raise when the search finds no match or more than one match.

    Returns:
        (str | list): The resolved file path. When the search finds nothing and `hard` is False an empty list is
            returned.

    Raises:
        FileNotFoundError: If `hard` is True and the search yields zero or multiple matches.
    """
    check_suffix(file, suffix)  # optional
    file = check_yolov5u_filename(str(file).strip())  # stringify, strip spaces, yolov5n -> yolov5nu

    # Nothing to resolve: empty name, existing local file, or gRPC Triton image
    if not file:
        return file
    if '://' not in file and Path(file).exists():  # '://' check required in Windows Python<3.10
        return file
    if file.lower().startswith('grpc://'):
        return file

    # Remote resource: reuse a local copy when present, otherwise download
    if download and file.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://', 'tcp://')):
        url = file  # warning: Pathlib turns :// -> :/
        file = url2file(file)  # '%2F' to '/', split https://url.com/file.txt?auth
        if Path(file).exists():
            LOGGER.info(f'Found {clean_url(url)} locally at {file}')  # file already exists
        else:
            downloads.safe_download(url=url, file=file, unzip=False)
        return file

    # Fallback: recursive search below ROOT/cfg
    files = glob.glob(str(ROOT / 'cfg' / '**' / file), recursive=True)  # find file
    if not files and hard:
        raise FileNotFoundError(f"'{file}' does not exist")
    elif len(files) > 1 and hard:
        raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
    return files[0] if len(files) else []  # return file
def check_yaml(file, suffix=('.yaml', '.yml'), hard=True):
    """
    Resolve a YAML file via check_file(), enforcing a YAML suffix.

    Args:
        file (str | Path): YAML file name, path or URL.
        suffix (tuple): Acceptable suffixes.
        hard (bool): Forwarded to check_file(); raise when the search is unsuccessful or ambiguous.

    Returns:
        (str | list): Whatever check_file() returns for the given arguments.
    """
    return check_file(file, suffix=suffix, hard=hard)
def check_is_path_safe(basedir, path):
    """
    Guard against path traversal by confirming the resolved path is an existing file below the intended directory.

    Args:
        basedir (Path | str): The intended directory.
        path (Path | str): The path to check.

    Returns:
        (bool): True if the resolved path is a file located under the resolved base directory, False otherwise.
    """
    root_parts = Path(basedir).resolve().parts
    target = Path(path).resolve()
    if not target.is_file():
        return False
    # The leading components of the target must coincide with those of the base directory
    return target.parts[:len(root_parts)] == root_parts
def check_imshow(warn=False):
    """
    Probe whether the environment can display images.

    Args:
        warn (bool): If True, log a warning describing the failure.

    Returns:
        (bool): True when a tiny test window could be shown and closed, False on any error.
    """
    try:
        if LINUX:
            # Linux needs a display and must not be one of the known headless environments
            assert 'DISPLAY' in os.environ and not (is_docker() or is_colab() or is_kaggle())
        probe = np.zeros((8, 8, 3), dtype=np.uint8)  # a small 8-pixel image
        cv2.imshow('test', probe)
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
    except Exception as e:
        if warn:
            LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}')
        return False
    return True
def check_yolo(verbose=True, device=''):
    """
    Log a human-readable YOLO software and hardware summary.

    Args:
        verbose (bool): If True, include CPU count, RAM and disk usage in the summary and clear notebook output.
        device (str): Device specification forwarded to select_device().
    """
    import psutil

    from ultralytics.utils.torch_utils import select_device

    # NOTE(review): nesting of the Colab cleanup under the Jupyter check is reconstructed from
    # whitespace-stripped source — confirm against the canonical file.
    if is_jupyter():
        if check_requirements('wandb', install=False):
            os.system('pip uninstall -y wandb')  # uninstall wandb: unwanted account creation prompt with infinite hang
        if is_colab():
            shutil.rmtree('sample_data', ignore_errors=True)  # remove colab /sample_data directory

    s = ''
    if verbose:
        # System info
        gib = 1 << 30  # bytes per GiB
        ram = psutil.virtual_memory().total
        disk = shutil.disk_usage('/')
        occupied = disk.total - disk.free
        s = (f'({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, '
             f'{occupied / gib:.1f}/{disk.total / gib:.1f} GB disk)')
        with contextlib.suppress(Exception):  # clear display if ipython is installed
            from IPython import display
            display.clear_output()

    select_device(device=device, newline=False)
    LOGGER.info(f'Setup complete ✅ {s}')
def collect_system_info():
    """
    Gather and log system details: OS, environment, Python, install type, RAM, CPU and CUDA, followed by the
    status of every ultralytics requirement and, on GitHub Actions, selected runner variables.
    """
    import psutil

    from ultralytics.utils import ENVIRONMENT, is_git_dir
    from ultralytics.utils.torch_utils import get_cpu_info

    ram_info = psutil.virtual_memory().total / (1024 ** 3)  # Convert bytes to GB
    check_yolo()

    install_kind = 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other'
    cuda_version = torch.version.cuda if torch and torch.cuda.is_available() else None
    rows = (
        ('OS', platform.platform()),
        ('Environment', ENVIRONMENT),
        ('Python', sys.version.split()[0]),
        ('Install', install_kind),
        ('RAM', f'{ram_info:.2f} GB'),
        ('CPU', get_cpu_info()),
        ('CUDA', cuda_version),
    )
    LOGGER.info('\n' + ''.join(f'{label:<20}{value}\n' for label, value in rows))

    for r in parse_requirements(package='ultralytics'):
        try:
            current = metadata.version(r.name)
            # NOTE(review): with hard=True check_version() raises ModuleNotFoundError on an unmet specifier, so the
            # '❌ ' alternative below is only reachable through the except branch — confirm this is intended.
            is_met = '✅ ' if check_version(current, str(r.specifier), hard=True) else '❌ '
        except metadata.PackageNotFoundError:
            current = '(not installed)'
            is_met = '❌ '
        LOGGER.info(f'{r.name:<20}{is_met}{current}{r.specifier}')

    if is_github_action_running():
        runner_keys = ('RUNNER_OS', 'GITHUB_EVENT_NAME', 'GITHUB_WORKFLOW', 'GITHUB_ACTOR', 'GITHUB_REPOSITORY',
                       'GITHUB_REPOSITORY_OWNER')
        LOGGER.info('\n' + ''.join(f'{key}: {os.getenv(key)}\n' for key in runner_keys))
def check_amp(model):
    """
    Check that PyTorch Automatic Mixed Precision (AMP) works correctly for a YOLOv8 model. A failed check signals
    AMP anomalies on the system that may cause NaN losses or zero-mAP results, so AMP will be disabled during
    training.

    Args:
        model (nn.Module): A YOLOv8 model instance.

    Example:
        ```python
        from ultralytics import YOLO
        from ultralytics.utils.checks import check_amp

        model = YOLO('yolov8n.pt').model.cuda()
        check_amp(model)
        ```

    Returns:
        (bool): True if AMP works correctly (or the check had to be skipped), False if the model lives on a
            'cpu'/'mps' device or the FP32 and AMP results disagree.
    """
    device = next(model.parameters()).device  # get model device
    if device.type in ('cpu', 'mps'):
        return False  # AMP only used on CUDA devices

    def amp_allclose(m, im):
        """Compare FP32 and AMP inference results for closeness."""
        fp32_boxes = m(im, device=device, verbose=False)[0].boxes.data  # FP32 inference
        with torch.cuda.amp.autocast(True):
            amp_boxes = m(im, device=device, verbose=False)[0].boxes.data  # AMP inference
        del m
        if fp32_boxes.shape != amp_boxes.shape:
            return False
        return torch.allclose(fp32_boxes, amp_boxes.float(), atol=0.5)  # close to 0.5 absolute tolerance

    im = ASSETS / 'bus.jpg'  # image to check
    prefix = colorstr('AMP: ')
    LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...')
    warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
    try:
        from ultralytics import YOLO
        assert amp_allclose(YOLO('yolov8n.pt'), im)
    except ConnectionError:
        LOGGER.warning(f'{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}')
    except (AttributeError, ModuleNotFoundError):
        LOGGER.warning(f'{prefix}checks skipped ⚠️. '
                       f'Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}')
    except AssertionError:
        LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to '
                       f'NaN losses or zero-mAP results, so AMP will be disabled during training.')
        return False
    else:
        LOGGER.info(f'{prefix}checks passed ✅')
    return True
def git_describe(path=ROOT):  # path must be a directory
    """
    Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe.

    Args:
        path (Path | str): Directory inside the git repository to describe.

    Returns:
        (str): Output of 'git describe --tags --long --always' without surrounding whitespace, or '' when git is
            unavailable or the command fails for any reason.
    """
    with contextlib.suppress(Exception):  # best-effort: any failure falls through to ''
        # Pass an argument list (no shell) so paths with spaces or shell metacharacters reach git verbatim
        cmd = ['git', '-C', str(path), 'describe', '--tags', '--long', '--always']
        return subprocess.check_output(cmd).decode().strip()
    return ''
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
    """
    Log the arguments of the calling function, or an explicitly supplied dict of arguments.

    Args:
        args (dict, optional): Arguments to print; when None they are read from the caller's frame.
        show_file (bool): Prefix the output with the caller's file (relative to ROOT when possible).
        show_func (bool): Prefix the output with the caller's function name.
    """

    def strip_auth(v):
        """Clean longer Ultralytics HUB URLs by stripping potential authentication information."""
        is_long_url = isinstance(v, str) and v.startswith('http') and len(v) > 100
        return clean_url(v) if is_long_url else v

    caller = inspect.currentframe().f_back  # previous frame
    frame_info = inspect.getframeinfo(caller)
    file, func = frame_info.filename, frame_info.function

    if args is None:  # get args automatically
        arg_names, _, _, frame_locals = inspect.getargvalues(caller)
        args = {k: v for k, v in frame_locals.items() if k in arg_names}

    try:
        file = Path(file).resolve().relative_to(ROOT).with_suffix('')
    except ValueError:  # caller lives outside ROOT
        file = Path(file).stem

    file_part = f'{file}: ' if show_file else ''
    func_part = f'{func}: ' if show_func else ''
    rendered = ', '.join(f'{k}={strip_auth(v)}' for k, v in args.items())
    LOGGER.info(colorstr(file_part + func_part) + rendered)
def cuda_device_count() -> int:
    """
    Count the NVIDIA GPUs visible in the environment by querying nvidia-smi.

    Returns:
        (int): The number of NVIDIA GPUs available; 0 when nvidia-smi is missing, fails, or prints something that
            is not an integer.
    """
    query = ['nvidia-smi', '--query-gpu=count', '--format=csv,noheader,nounits']
    try:
        report = subprocess.check_output(query, encoding='utf-8')
        # Only the first line of the report is read
        return int(report.strip().split('\n')[0])
    except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
        # Command failed, nvidia-smi not found, or non-integer output: assume no GPUs
        return 0
def cuda_is_available() -> bool:
    """
    Tell whether at least one NVIDIA GPU is present, based on cuda_device_count().

    Returns:
        (bool): True if one or more NVIDIA GPUs are available, False otherwise.
    """
    gpu_total = cuda_device_count()
    return gpu_total > 0