#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# This code is based on
# https://github.com/ultralytics/yolov5/blob/master/utils/dataloaders.py

import os

from torch.utils.data import dataloader, distributed

from .datasets import TrainValDataset
from yolov6.utils.events import LOGGER
from yolov6.utils.torch_utils import torch_distributed_zero_first


def create_dataloader(
    path,
    img_size,
    batch_size,
    stride,
    hyp=None,
    augment=False,
    check_images=False,
    check_labels=False,
    pad=0.0,
    rect=False,
    rank=-1,
    workers=8,
    shuffle=False,
    class_names=None,
    task='Train',
):
    """Create general dataloader.

    Builds a ``TrainValDataset`` and wraps it in a ``TrainValDataLoader``.

    Args:
        path, img_size, hyp, augment, check_images, check_labels, pad, rect,
            class_names, task: forwarded unchanged to ``TrainValDataset``.
        batch_size: requested batch size; clamped to ``len(dataset)`` before
            the loader is built.
        stride: forwarded to ``TrainValDataset`` as ``int(stride)``.
        rank: ``-1`` means no distributed sampler is used; any other value
            installs a ``DistributedSampler`` over the dataset.
        workers: upper bound on the number of loader worker processes.
        shuffle: whether to shuffle; forced to ``False`` when ``rect`` is set.

    Returns:
        A ``(dataloader, dataset)`` tuple.
    """
    if rect and shuffle:
        LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
        shuffle = False

    with torch_distributed_zero_first(rank):
        dataset = TrainValDataset(
            path,
            img_size,
            batch_size,
            augment=augment,
            hyp=hyp,
            rect=rect,
            check_images=check_images,
            # Previously dropped here, so the caller's flag had no effect.
            check_labels=check_labels,
            stride=int(stride),
            pad=pad,
            rank=rank,
            class_names=class_names,
            task=task,
        )

    batch_size = min(batch_size, len(dataset))

    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to 1 so the integer division below cannot raise TypeError.
    cpus_per_process = (os.cpu_count() or 1) // int(os.getenv('WORLD_SIZE', 1))
    # number of workers
    workers = min(cpus_per_process, batch_size if batch_size > 1 else 0, workers)

    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)

    loader = TrainValDataLoader(
        dataset,
        batch_size=batch_size,
        # When a sampler is supplied, shuffling is delegated to it.
        shuffle=shuffle and sampler is None,
        num_workers=workers,
        sampler=sampler,
        pin_memory=True,
        collate_fn=TrainValDataset.collate_fn,
    )
    return loader, dataset


class TrainValDataLoader(dataloader.DataLoader):
    """Dataloader that reuses workers.

    Uses same syntax as vanilla DataLoader.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Wrap the batch sampler so it never exhausts. object.__setattr__ is
        # used to bypass DataLoader's own attribute-setting logic.
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        # A single underlying iterator is created once and kept for the
        # lifetime of the loader; every __iter__ call draws from it.
        self.iterator = super().__iter__()

    def __len__(self):
        """Return the length of the original (wrapped) batch sampler."""
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        """Yield ``len(self)`` batches from the persistent iterator."""
        for _ in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler:
    """Sampler that repeats forever.

    Args:
        sampler (Sampler): the sampler to iterate over repeatedly.
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        # Restart the wrapped sampler each time it is exhausted.
        while True:
            yield from iter(self.sampler)