#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import random

import torch
import torch.distributed as dist

from yolox.utils import synchronize


class DataPrefetcher:
    """Overlap host-to-device batch copies with work on the current stream.

    DataPrefetcher is inspired by code of following file:
    https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py
    It could speedup your pytorch dataloader. For more information, please check
    https://github.com/NVIDIA/apex/issues/304#issuecomment-493562789.

    The wrapped loader must yield 4-item batches; only the first two items
    (input and target) are kept, the remaining two are discarded. One batch
    is always staged ahead of the caller on a dedicated CUDA stream.
    """

    def __init__(self, loader):
        """Wrap ``loader`` and immediately start staging its first batch.

        Args:
            loader: Iterable yielding ``(input, target, _, _)`` batches whose
                first two items expose ``.cuda()``.
        """
        self.loader = iter(loader)
        # Side stream on which the host-to-device copies are issued.
        self.stream = torch.cuda.Stream()
        # Indirection hooks: the copy / record strategy for the input is
        # looked up through these attributes rather than called directly.
        self.input_cuda = self._input_cuda_for_image
        self.record_stream = DataPrefetcher._record_stream_for_image
        self.preload()

    def preload(self):
        """Fetch the next batch and enqueue its copy to the GPU.

        When the loader is exhausted, ``next_input`` and ``next_target`` are
        both set to ``None`` and nothing is copied.
        """
        try:
            self.next_input, self.next_target, _, _ = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return

        # Issue both copies on the side stream so they do not serialize with
        # kernels queued on the current stream.
        with torch.cuda.stream(self.stream):
            self.input_cuda()
            self.next_target = self.next_target.cuda(non_blocking=True)

    def next(self):
        """Return the staged ``(input, target)`` pair and stage the next one.

        Returns:
            tuple: ``(input, target)`` as CUDA tensors, or ``(None, None)``
            once the underlying loader has been exhausted.
        """
        # The current stream must not touch the batch before its copy on the
        # side stream has completed.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch_input = self.next_input
        batch_target = self.next_target
        # The tensors were allocated on the side stream but are consumed on
        # the current one; record_stream registers that use with the tensor.
        if batch_input is not None:
            self.record_stream(batch_input)
        if batch_target is not None:
            batch_target.record_stream(torch.cuda.current_stream())
        self.preload()
        return batch_input, batch_target

    def _input_cuda_for_image(self):
        """Replace ``self.next_input`` with its CUDA copy (non-blocking)."""
        self.next_input = self.next_input.cuda(non_blocking=True)

    @staticmethod
    def _record_stream_for_image(input):
        """Record the current CUDA stream as a user of ``input``.

        The parameter name is kept as-is so the signature stays unchanged.
        """
        input.record_stream(torch.cuda.current_stream())


def random_resize(data_loader, exp, epoch, rank, is_distributed):
    """Choose an input size on rank 0 and apply it to ``data_loader``.

    Rank 0 picks the size; when ``is_distributed`` is true the value is
    broadcast from rank 0 so every process passes the same number to
    ``data_loader.change_input_dim``.

    Args:
        data_loader: Object exposing
            ``change_input_dim(multiple=..., random_range=...)``.
        exp: Object exposing ``max_epoch``, ``input_size`` and
            ``random_size``. ``random_size`` is unpacked into
            ``random.randint`` (inclusive bounds).
        epoch: Current epoch number, compared against ``exp.max_epoch - 10``.
        rank: Rank of this process; only rank 0 draws the size.
        is_distributed: Whether to synchronize and broadcast across
            processes.

    Returns:
        Whatever ``data_loader.change_input_dim`` returns.
    """
    # Zero-initialised so a process that neither fills nor receives the value
    # never reads uninitialised memory (the legacy ``torch.LongTensor(1)``
    # constructor leaves its contents undefined).
    size_tensor = torch.zeros(1, dtype=torch.long, device="cuda")

    if is_distributed:
        synchronize()

    if rank == 0:
        if epoch > exp.max_epoch - 10:
            # Final epochs: use the configured size instead of a random one.
            # NOTE(review): fill_ needs a scalar, so exp.input_size is
            # presumably a single number here; confirm against the exp class.
            size = exp.input_size
        else:
            # Random sizes are drawn in units of 32.
            size = 32 * random.randint(*exp.random_size)
        size_tensor.fill_(size)

    if is_distributed:
        synchronize()
        dist.broadcast(size_tensor, 0)

    input_size = data_loader.change_input_dim(
        multiple=size_tensor.item(), random_range=None
    )
    return input_size