yolov6 / yolov6 /data /data_load.py
yourusername's picture
:beers: cheers
2a27594
raw history blame
No virus
2.78 kB
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# This code is based on
# https://github.com/ultralytics/yolov5/blob/master/utils/dataloaders.py
import os
from torch.utils.data import dataloader, distributed
from .datasets import TrainValDataset
from yolov6.utils.events import LOGGER
from yolov6.utils.torch_utils import torch_distributed_zero_first
def create_dataloader(
    path,
    img_size,
    batch_size,
    stride,
    hyp=None,
    augment=False,
    check_images=False,
    check_labels=False,
    pad=0.0,
    rect=False,
    rank=-1,
    workers=8,
    shuffle=False,
    data_dict=None,
    task="Train",
):
    """Create a ``TrainValDataset`` and a ``TrainValDataLoader`` that serves it.

    Args:
        path: Forwarded unchanged to ``TrainValDataset``.
        img_size: Forwarded unchanged to ``TrainValDataset``.
        batch_size: Requested batch size. It is forwarded to the dataset as
            given, then clamped to ``len(dataset)`` before the loader is built.
        stride: Converted with ``int()`` and forwarded to ``TrainValDataset``.
        hyp: Forwarded unchanged to ``TrainValDataset``.
        augment: Forwarded unchanged to ``TrainValDataset``.
        check_images: Forwarded unchanged to ``TrainValDataset``.
        check_labels: Forwarded unchanged to ``TrainValDataset``.
        pad: Forwarded unchanged to ``TrainValDataset``.
        rect: Forwarded to ``TrainValDataset``. When truthy together with
            ``shuffle``, a warning is logged and shuffling is turned off.
        rank: Process rank. ``-1`` means no ``DistributedSampler`` is used;
            any other value wraps the dataset in one.
        workers: Upper bound on the number of loader worker processes.
        shuffle: Whether to shuffle. Applied by the ``DistributedSampler`` when
            one is created, otherwise by the loader itself.
        data_dict: Forwarded unchanged to ``TrainValDataset``.
        task: Forwarded unchanged to ``TrainValDataset``.

    Returns:
        tuple: ``(dataloader, dataset)``.
    """
    if rect and shuffle:
        LOGGER.warning(
            "WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False"
        )
        shuffle = False

    # NOTE(review): presumably lets the rank-0 process build the dataset before
    # the other ranks do -- confirm against torch_distributed_zero_first.
    with torch_distributed_zero_first(rank):
        dataset = TrainValDataset(
            path,
            img_size,
            batch_size,
            augment=augment,
            hyp=hyp,
            rect=rect,
            check_images=check_images,
            check_labels=check_labels,
            stride=int(stride),
            pad=pad,
            rank=rank,
            data_dict=data_dict,
            task=task,
        )

    # A batch can never hold more samples than the dataset provides.
    batch_size = min(batch_size, len(dataset))

    # os.cpu_count() returns None when the CPU count cannot be determined;
    # fall back to a single CPU instead of failing on ``None // int``.
    cpu_total = os.cpu_count() or 1
    cpus_per_process = cpu_total // int(os.getenv("WORLD_SIZE", 1))
    # Worker count is capped by the per-process CPU share, by the batch size
    # (no workers at all for batches of at most one sample) and by ``workers``.
    workers = min(
        cpus_per_process,
        batch_size if batch_size > 1 else 0,
        workers,
    )

    if rank == -1:
        sampler = None
    else:
        sampler = distributed.DistributedSampler(dataset, shuffle=shuffle)

    loader = TrainValDataLoader(
        dataset,
        batch_size=batch_size,
        # When a sampler is supplied it owns the ordering, so the loader's own
        # shuffle flag is only enabled without one.
        shuffle=shuffle and sampler is None,
        num_workers=workers,
        sampler=sampler,
        pin_memory=True,
        collate_fn=TrainValDataset.collate_fn,
    )
    return loader, dataset
class TrainValDataLoader(dataloader.DataLoader):
    """DataLoader variant that reuses its workers.

    Accepts exactly the same arguments as the vanilla ``DataLoader``. A single
    underlying iterator is created at construction time over a batch sampler
    that repeats forever, and every pass over the loader draws from it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        endless_batches = _RepeatSampler(self.batch_sampler)
        # Assigned through object.__setattr__ so the attribute is replaced
        # directly on the instance, bypassing the class's own __setattr__.
        object.__setattr__(self, "batch_sampler", endless_batches)
        # Created once and kept; __iter__ below keeps pulling from it.
        self.iterator = super().__iter__()

    def __len__(self):
        # ``batch_sampler`` is the repeating wrapper; its ``sampler`` attribute
        # is the original batch sampler, whose length is reported here.
        wrapped = self.batch_sampler.sampler
        return len(wrapped)

    def __iter__(self):
        # Yield exactly len(self) batches from the persistent iterator.
        total = len(self)
        for _ in range(total):
            yield next(self.iterator)
class _RepeatSampler:
    """Sampler wrapper that never runs out.

    Args:
        sampler (Sampler): iterable that is iterated again from the start
            every time it is exhausted.
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        # Restart the wrapped sampler each time a full pass completes.
        while True:
            for item in self.sampler:
                yield item