yolov6 / yolov6 /data /data_load.py
Theivaprakasham's picture
adding app
be49b0b
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# This code is based on
# https://github.com/ultralytics/yolov5/blob/master/utils/dataloaders.py
import os
from torch.utils.data import dataloader, distributed
from .datasets import TrainValDataset
from yolov6.utils.events import LOGGER
from yolov6.utils.torch_utils import torch_distributed_zero_first
def create_dataloader(path, img_size, batch_size, stride, hyp=None, augment=False, check_images=False, check_labels=False, pad=0.0, rect=False, rank=-1, workers=8, shuffle=False,class_names=None, task='Train'):
    '''Create general dataloader.

    Builds a TrainValDataset from the given arguments and wraps it in a
    TrainValDataLoader.

    Args:
        path: Forwarded to TrainValDataset as its first positional argument.
        img_size: Forwarded to TrainValDataset.
        batch_size: Requested batch size; clamped to ``len(dataset)``.
        stride: Cast to ``int`` and forwarded to TrainValDataset.
        hyp: Forwarded to TrainValDataset.
        augment: Forwarded to TrainValDataset.
        check_images: Forwarded to TrainValDataset.
        check_labels: Accepted for interface compatibility; not used by this
            function (see the NOTE below).
        pad: Forwarded to TrainValDataset.
        rect: Forwarded to TrainValDataset; when truthy, shuffling is disabled.
        rank: Passed to ``torch_distributed_zero_first`` and TrainValDataset.
            ``-1`` means no DistributedSampler is created.
        workers: Upper bound on the number of DataLoader worker processes.
        shuffle: Whether to shuffle; forced to False when ``rect`` is set.
        class_names: Forwarded to TrainValDataset.
        task: Forwarded to TrainValDataset.

    Returns dataloader and dataset
    '''
    if rect and shuffle:
        LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
        shuffle = False

    # NOTE(review): check_labels is never forwarded to TrainValDataset here.
    # Confirm whether the dataset accepts it before wiring it through.
    with torch_distributed_zero_first(rank):
        dataset = TrainValDataset(path, img_size, batch_size,
                                  augment=augment,
                                  hyp=hyp,
                                  rect=rect,
                                  check_images=check_images,
                                  stride=int(stride),
                                  pad=pad,
                                  rank=rank,
                                  class_names=class_names,
                                  task=task)

    # Never ask for a batch larger than the dataset itself.
    batch_size = min(batch_size, len(dataset))

    # os.cpu_count() returns None when the CPU count cannot be determined;
    # fall back to 1 so the floor division below cannot raise TypeError.
    cpu_total = os.cpu_count() or 1
    cpus_per_process = cpu_total // int(os.getenv('WORLD_SIZE', 1))
    # A batch of at most one sample gets no worker processes at all.
    workers = min([cpus_per_process, batch_size if batch_size > 1 else 0, workers])  # number of workers

    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    return TrainValDataLoader(dataset,
                              batch_size=batch_size,
                              # shuffle and an explicit sampler are never passed together;
                              # the DistributedSampler does the shuffling when present.
                              shuffle=shuffle and sampler is None,
                              num_workers=workers,
                              sampler=sampler,
                              pin_memory=True,
                              collate_fn=TrainValDataset.collate_fn), dataset
class TrainValDataLoader(dataloader.DataLoader):
    """DataLoader variant that builds its underlying iterator only once.

    The iterator is created in ``__init__`` over an endlessly repeating batch
    sampler and is then reused by every ``__iter__`` call, which is how the
    original author describes it "reusing workers".
    Uses same syntax as vanilla DataLoader
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        endless_batches = _RepeatSampler(self.batch_sampler)
        # Assigned through object.__setattr__ so that any attribute guard on
        # the DataLoader class itself is bypassed (presumably DataLoader
        # rejects re-assigning batch_sampler after init -- confirm).
        object.__setattr__(self, 'batch_sampler', endless_batches)
        # Must be created after the sampler swap so it draws from the
        # repeating sampler.
        self.iterator = super().__iter__()

    def __len__(self):
        # batch_sampler is the _RepeatSampler wrapper; its .sampler is the
        # batch sampler that the base class originally built.
        original_batch_sampler = self.batch_sampler.sampler
        return len(original_batch_sampler)

    def __iter__(self):
        # Hand out exactly len(self) items from the shared iterator, then stop.
        remaining = len(self)
        while remaining > 0:
            yield next(self.iterator)
            remaining -= 1
class _RepeatSampler:
    """Endless view over a sampler: restarts it each time it runs out.

    Args:
        sampler (Sampler): iterable that is re-iterated from the start on
            every pass.
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        # Never terminates on its own; the consumer decides when to stop.
        while True:
            one_pass = iter(self.sampler)
            yield from one_pass