from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path

import cv2
import numpy as np
import torch
import torchvision
from tqdm import tqdm

from ..utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable
from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms
from .base import BaseDataset
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image_label
					
					
						
class YOLODataset(BaseDataset):
    """
    Dataset class for loading object detection and/or segmentation labels in YOLO format.

    Args:
        data (dict, optional): A dataset YAML dictionary. Defaults to None.
        use_segments (bool, optional): If True, segmentation masks are used as labels. Defaults to False.
        use_keypoints (bool, optional): If True, keypoints are used as labels. Defaults to False.

    Returns:
        (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
    """
    cache_version = '1.0.2'  # dataset labels *.cache version
    rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]

    def __init__(self, *args, data=None, use_segments=False, use_keypoints=False, **kwargs):
        """Initialize the dataset; segment and keypoint labels are mutually exclusive."""
        self.use_segments = use_segments
        self.use_keypoints = use_keypoints
        self.data = data
        assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
        super().__init__(*args, **kwargs)
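
    # Construction sketch (illustrative). Arguments other than `data`, `use_segments` and
    # `use_keypoints` are forwarded to BaseDataset; the `img_path`/`imgsz` keywords below are
    # assumptions about that base-class signature (see base.py), not guaranteed names.
    #
    #   dataset = YOLODataset(img_path='coco128/images/train2017', imgsz=640,
    #                         data=data_yaml_dict, use_segments=False, use_keypoints=False)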
					
					
						
    def cache_labels(self, path=Path('./labels.cache')):
        """Cache dataset labels, check images and read shapes.

        Args:
            path (Path): Path where to save the cache file (default: Path('./labels.cache')).

        Returns:
            (dict): Dictionary of cached labels and dataset statistics.
        """
        x = {'labels': []}
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f'{self.prefix}Scanning {path.parent / path.stem}...'
        total = len(self.im_files)
        nkpt, ndim = self.data.get('kpt_shape', (0, 0))
        if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
            raise ValueError("'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
                             "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'")
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(func=verify_image_label,
                                iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
                                             repeat(self.use_keypoints), repeat(len(self.data['names'])), repeat(nkpt),
                                             repeat(ndim)))
            pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
            for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    x['labels'].append(
                        dict(
                            im_file=im_file,
                            shape=shape,
                            cls=lb[:, 0:1],  # class column, shape (n, 1)
                            bboxes=lb[:, 1:],  # box columns, shape (n, 4)
                            segments=segments,
                            keypoints=keypoint,
                            normalized=True,
                            bbox_format='xywh'))
                if msg:
                    msgs.append(msg)
                pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'
            pbar.close()

        if msgs:
            LOGGER.info('\n'.join(msgs))
        if nf == 0:
            LOGGER.warning(f'{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
        x['hash'] = get_hash(self.label_files + self.im_files)
        x['results'] = nf, nm, ne, nc, len(self.im_files)
        x['msgs'] = msgs  # warnings
        x['version'] = self.cache_version  # cache version
        if is_dir_writeable(path.parent):
            if path.exists():
                path.unlink()  # remove existing *.cache file
            np.save(str(path), x)  # save cache for next time
            path.with_suffix('.cache.npy').rename(path)  # strip the .npy suffix appended by np.save
            LOGGER.info(f'{self.prefix}New cache created: {path}')
        else:
            LOGGER.warning(f'{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
        return x
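
    # Illustrative sketch of the dict assembled above. cache_labels() saves it to
    # '<labels_dir>.cache' (np.save followed by a rename to drop the '.npy' suffix) and
    # returns it; only the keys and the 'results' tuple order below are taken from the code,
    # the values are shown symbolically.
    #
    #   {
    #       'labels': [                                  # one dict per verified image
    #           dict(im_file=..., shape=..., cls=..., bboxes=..., segments=...,
    #                keypoints=..., normalized=True, bbox_format='xywh'),
    #           ...
    #       ],
    #       'hash': get_hash(label_files + im_files),
    #       'results': (nf, nm, ne, nc, total),          # found, missing, empty, corrupt, total
    #       'msgs': msgs,                                # per-image verification warnings
    #       'version': cache_version,
    #   }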
					
					
						
    def get_labels(self):
        """Returns dictionary of labels for YOLO training."""
        self.label_files = img2label_paths(self.im_files)
        cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
        try:
            import gc
            gc.disable()  # temporarily disable garbage collection to speed up unpickling the cache
            cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True  # load existing cache dict
            gc.enable()
            assert cache['version'] == self.cache_version  # cache version matches current
            assert cache['hash'] == get_hash(self.label_files + self.im_files)  # dataset is unchanged
        except (FileNotFoundError, AssertionError, AttributeError):
            cache, exists = self.cache_labels(cache_path), False  # rebuild the cache

        # Display cache results
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in (-1, 0):
            d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
            tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display results
            if cache['msgs']:
                LOGGER.info('\n'.join(cache['msgs']))  # display warnings
        if nf == 0:  # no labels found
            raise FileNotFoundError(f'{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}')

        # Read cache
        [cache.pop(k) for k in ('hash', 'version', 'msgs')]  # remove bookkeeping items
        labels = cache['labels']
        self.im_files = [lb['im_file'] for lb in labels]  # keep only the verified images

        # Check if the dataset is all boxes or all segments
        lengths = ((len(lb['cls']), len(lb['bboxes']), len(lb['segments'])) for lb in labels)
        len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
        if len_segments and len_boxes != len_segments:
            LOGGER.warning(
                f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, '
                f'len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. '
                'To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.')
            for lb in labels:
                lb['segments'] = []
        if len_cls == 0:
            raise ValueError(f'All labels empty in {cache_path}, can not start training without labels. {HELP_URL}')
        return labels
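
    # Each element of the returned list is one image's annotation dict with the keys written
    # by cache_labels(): 'im_file', 'shape', 'cls', 'bboxes', 'segments', 'keypoints',
    # 'normalized' and 'bbox_format'. Note the consistency rule enforced above: if the total
    # box and segment counts disagree, every 'segments' entry is emptied and training falls
    # back to boxes only. Illustrative access pattern (symbolic, not executed on import):
    #
    #   labels = self.get_labels()
    #   labels[0]['bboxes']    # (n, 4) array of normalized xywh boxes for the first image
    #   labels[0]['cls']       # (n, 1) array of class indices for those boxes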
					
					
						
    def build_transforms(self, hyp=None):
        """Builds and appends transforms to the list."""
        if self.augment:
            hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
            hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
            transforms = v8_transforms(self, self.imgsz, hyp)
        else:
            transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
        transforms.append(
            Format(bbox_format='xywh',
                   normalize=True,
                   return_mask=self.use_segments,
                   return_keypoint=self.use_keypoints,
                   batch_idx=True,
                   mask_ratio=hyp.mask_ratio,
                   mask_overlap=hyp.overlap_mask))
        return transforms
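
    # Resulting pipelines, as configured above:
    #   - self.augment True:   v8_transforms(...) followed by the final Format step
    #   - self.augment False:  LetterBox to a square (imgsz, imgsz), then Format
    # Format is set up here to emit normalized xywh boxes plus a per-sample batch index
    # (consumed later by collate_fn), and masks/keypoints only when use_segments/use_keypoints
    # are enabled.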
					
					
						
    def close_mosaic(self, hyp):
        """Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations."""
        hyp.mosaic = 0.0  # disable mosaic
        hyp.copy_paste = 0.0  # disable copy-paste
        hyp.mixup = 0.0  # disable mixup
        self.transforms = self.build_transforms(hyp)
					
					
						
    def update_labels_info(self, label):
        """Customize the label format here."""
        # Convert raw annotations into an Instances object; 'cls' stays a separate key.
        bboxes = label.pop('bboxes')
        segments = label.pop('segments')
        keypoints = label.pop('keypoints', None)
        bbox_format = label.pop('bbox_format')
        normalized = label.pop('normalized')
        label['instances'] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
        return label
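
    # Before/after sketch (keys taken from the cached label dicts above): an input such as
    #   {'im_file': ..., 'shape': ..., 'cls': ..., 'bboxes': ..., 'segments': ...,
    #    'keypoints': ..., 'normalized': True, 'bbox_format': 'xywh'}
    # leaves this method with 'bboxes', 'segments', 'keypoints', 'bbox_format' and
    # 'normalized' collapsed into a single 'instances' entry (an Instances object), while
    # 'im_file', 'shape' and 'cls' pass through unchanged.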
					
					
						
    @staticmethod
    def collate_fn(batch):
        """Collates data samples into batches."""
        new_batch = {}
        keys = batch[0].keys()
        values = list(zip(*[list(b.values()) for b in batch]))
        for i, k in enumerate(keys):
            value = values[i]
            if k == 'img':
                value = torch.stack(value, 0)
            if k in ['masks', 'keypoints', 'bboxes', 'cls']:
                value = torch.cat(value, 0)
            new_batch[k] = value
        new_batch['batch_idx'] = list(new_batch['batch_idx'])
        for i in range(len(new_batch['batch_idx'])):
            new_batch['batch_idx'][i] += i  # offset each sample's batch_idx by its position in the batch
        new_batch['batch_idx'] = torch.cat(new_batch['batch_idx'], 0)
        return new_batch
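
    # Usage sketch (illustrative; `dataset` is assumed to be an already-constructed
    # YOLODataset). collate_fn is handed to a torch DataLoader so that a variable number of
    # boxes per image can be merged into flat tensors:
    #
    #   from torch.utils.data import DataLoader
    #   loader = DataLoader(dataset, batch_size=16, shuffle=True,
    #                       collate_fn=YOLODataset.collate_fn)
    #   batch = next(iter(loader))
    #   # batch['img']       -> images stacked along dim 0
    #   # batch['cls'], batch['bboxes'] -> all labels in the batch concatenated along dim 0
    #   # batch['batch_idx'] -> 1-D tensor mapping every label row back to its image index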
					
					
						
class ClassificationDataset(torchvision.datasets.ImageFolder):
    """
    YOLO Classification Dataset.

    Args:
        root (str): Dataset path.

    Attributes:
        cache_ram (bool): True if images should be cached in RAM, False otherwise.
        cache_disk (bool): True if images should be cached on disk, False otherwise.
        samples (list): List of samples containing file, index, npy, and im.
        torch_transforms (callable): torchvision transforms applied to the dataset.
        album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
    """

    def __init__(self, root, args, augment=False, cache=False):
        """
        Initialize the YOLO classification dataset with root, image size, augmentation, and cache settings.

        Args:
            root (str): Dataset path.
            args (Namespace): Argument parser namespace containing dataset-related settings.
            augment (bool, optional): True if the dataset should be augmented, False otherwise. Defaults to False.
            cache (bool | str, optional): Cache setting, can be True, False, 'ram' or 'disk'. Defaults to False.
        """
        super().__init__(root=root)
        if augment and args.fraction < 1.0:  # reduce training fraction
            self.samples = self.samples[:round(len(self.samples) * args.fraction)]
        self.cache_ram = cache is True or cache == 'ram'
        self.cache_disk = cache == 'disk'
        self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples]  # file, index, npy, im
        self.torch_transforms = classify_transforms(args.imgsz)
        self.album_transforms = classify_albumentations(
            augment=augment,
            size=args.imgsz,
            scale=(1.0 - args.scale, 1.0),
            hflip=args.fliplr,
            vflip=args.flipud,
            hsv_h=args.hsv_h,
            hsv_s=args.hsv_s,
            hsv_v=args.hsv_v,
            mean=(0.0, 0.0, 0.0),
            std=(1.0, 1.0, 1.0),
            auto_aug=False) if augment else None

    def __getitem__(self, i):
        """Returns the transformed image and class index for sample i."""
        f, j, fn, im = self.samples[i]  # filename, class index, cached .npy path, cached image
        if self.cache_ram and im is None:
            im = self.samples[i][3] = cv2.imread(f)
        elif self.cache_disk:
            if not fn.exists():  # save image to *.npy on first access
                np.save(fn.as_posix(), cv2.imread(f))
            im = np.load(fn)
        else:  # no caching, read the image every time
            im = cv2.imread(f)  # BGR
        if self.album_transforms:
            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
        else:
            sample = self.torch_transforms(im)
        return {'img': sample, 'cls': j}

    def __len__(self) -> int:
        return len(self.samples)
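
    # Usage sketch (illustrative; the root path and hyperparameter values are hypothetical,
    # and `args` only needs the fields read in __init__ above):
    #
    #   from types import SimpleNamespace
    #   args = SimpleNamespace(imgsz=224, fraction=1.0, scale=0.5, fliplr=0.5, flipud=0.0,
    #                          hsv_h=0.015, hsv_s=0.7, hsv_v=0.4)
    #   dataset = ClassificationDataset(root='datasets/imagenette/train', args=args,
    #                                   augment=True, cache=False)
    #   sample = dataset[0]    # {'img': transformed image, 'cls': integer class index}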
					
					
						
class SemanticDataset(BaseDataset):
    """Semantic segmentation dataset (currently a stub with no functionality beyond BaseDataset)."""

    def __init__(self):
        """Initialize a SemanticDataset object."""
        super().__init__()