Spaces:
Running
Running
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# This code is based on
# https://github.com/ultralytics/yolov5/blob/master/utils/dataloaders.py
import os | |
from torch.utils.data import dataloader, distributed | |
from .datasets import TrainValDataset | |
from yolov6.utils.events import LOGGER | |
from yolov6.utils.torch_utils import torch_distributed_zero_first | |
def create_dataloader(path, img_size, batch_size, stride, hyp=None, augment=False,
                      check_images=False, check_labels=False, pad=0.0, rect=False,
                      rank=-1, workers=8, shuffle=False, class_names=None, task='Train'):
    """Create general dataloader.

    Builds a ``TrainValDataset`` and wraps it in a ``TrainValDataLoader``.

    Args:
        path: Forwarded positionally to ``TrainValDataset``.
        img_size: Forwarded positionally to ``TrainValDataset``.
        batch_size: Requested batch size. Clamped to ``len(dataset)`` before
            the loader is built; the unclamped value is what the dataset
            receives.
        stride: Converted with ``int()`` and forwarded to the dataset.
        hyp: Forwarded to the dataset.
        augment: Forwarded to the dataset.
        check_images: Forwarded to the dataset.
        check_labels: Accepted, but this function does not forward it to
            ``TrainValDataset``.
            NOTE(review): confirm whether the dataset expects this flag.
        pad: Forwarded to the dataset.
        rect: Forwarded to the dataset. When true together with ``shuffle``,
            a warning is logged and shuffling is disabled.
        rank: ``-1`` means no ``DistributedSampler`` is created; any other
            value builds one over the dataset. Also passed to
            ``torch_distributed_zero_first`` and to the dataset.
        workers: Upper bound on the number of loader worker processes.
        shuffle: Whether to shuffle. Given to the ``DistributedSampler`` when
            one is used, otherwise to the loader itself.
        class_names: Forwarded to the dataset.
        task: Forwarded to the dataset.

    Returns:
        A ``(dataloader, dataset)`` tuple.
    """
    if rect and shuffle:
        LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
        shuffle = False

    with torch_distributed_zero_first(rank):
        dataset = TrainValDataset(
            path,
            img_size,
            batch_size,
            augment=augment,
            hyp=hyp,
            rect=rect,
            check_images=check_images,
            stride=int(stride),
            pad=pad,
            rank=rank,
            class_names=class_names,
            task=task,
        )

    batch_size = min(batch_size, len(dataset))

    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to 1 so the integer division below cannot raise TypeError.
    cpu_total = os.cpu_count() or 1
    cpus_per_process = cpu_total // int(os.getenv('WORLD_SIZE', 1))
    # Number of workers: never more than the per-process CPU share, the
    # batch size (0 workers when the batch size is <= 1), or the caller's cap.
    workers = min([cpus_per_process, batch_size if batch_size > 1 else 0, workers])

    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)

    loader = TrainValDataLoader(
        dataset,
        batch_size=batch_size,
        # Shuffling is delegated to the sampler whenever one is supplied.
        shuffle=shuffle and sampler is None,
        num_workers=workers,
        sampler=sampler,
        pin_memory=True,
        collate_fn=TrainValDataset.collate_fn,
    )
    return loader, dataset
class TrainValDataLoader(dataloader.DataLoader):
    """DataLoader that keeps one underlying iterator alive across epochs.

    Constructor arguments are identical to the vanilla DataLoader. The
    batch sampler is wrapped in a ``_RepeatSampler`` and a single base-class
    iterator is created up front; every pass over the loader pulls the next
    ``len(self)`` batches from that same iterator.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        endless_batches = _RepeatSampler(self.batch_sampler)
        # Assigned through object.__setattr__, which bypasses DataLoader's
        # own __setattr__ (presumably because it rejects re-assigning
        # batch_sampler after construction -- confirm against torch docs).
        object.__setattr__(self, 'batch_sampler', endless_batches)
        self.iterator = super().__iter__()

    def __len__(self):
        # ``batch_sampler`` is the repeating wrapper; its ``sampler``
        # attribute is the object whose length is reported.
        wrapped = self.batch_sampler.sampler
        return len(wrapped)

    def __iter__(self):
        remaining = len(self)
        while remaining > 0:
            yield next(self.iterator)
            remaining -= 1
class _RepeatSampler:
    """Endless wrapper around a sampler.

    Iterating over an instance yields every item of the wrapped sampler,
    then starts over with a fresh pass, without ever terminating.

    Args:
        sampler (Sampler): the sampler to cycle through.
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            # A new pass over the wrapped sampler is started each time round.
            for item in self.sampler:
                yield item