Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# -*- coding:utf-8 -*- | |
# Copyright (c) Megvii, Inc. and its affiliates. | |
import torch | |
import torch.distributed as dist | |
from yolox.utils import synchronize | |
import random | |
class DataPrefetcher:
    """Prefetch the next batch to the GPU on a dedicated CUDA side stream.

    While the caller works on the batch returned by :meth:`next`, the host-to-device
    copy of the following batch is already queued on ``self.stream``. This lets the
    copy overlap with work on the current stream.

    Inspired by:
    https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py
    For more information, see
    https://github.com/NVIDIA/apex/issues/304#issuecomment-493562789.
    """

    def __init__(self, loader):
        # Iterate the loader manually so batches can be pulled one step ahead.
        self.loader = iter(loader)
        # Side stream on which the asynchronous (non_blocking) copies are issued.
        self.stream = torch.cuda.Stream()
        # Indirection for how the pending input is moved to the GPU and how its
        # use on the current stream is recorded.
        self.input_cuda = self._input_cuda_for_image
        self.record_stream = DataPrefetcher._record_stream_for_image
        # Start copying the first batch immediately.
        self.preload()

    def preload(self):
        """Fetch the next batch from the loader and start its copy to the GPU.

        Loader items are unpacked as 4-tuples; only the first two (input, target)
        are kept. When the loader is exhausted, ``next_input`` and ``next_target``
        are set to ``None``.
        """
        try:
            self.next_input, self.next_target, _, _ = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return

        # Issue the copies on the side stream so they do not block the current stream.
        with torch.cuda.stream(self.stream):
            self.input_cuda()
            self.next_target = self.next_target.cuda(non_blocking=True)

    def next(self):
        """Return the prefetched ``(input, target)`` pair and prefetch the following one.

        Returns ``(None, None)`` once the underlying loader is exhausted.
        """
        # Ensure the side-stream copy has completed before the current stream uses it.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch_input = self.next_input
        batch_target = self.next_target
        # The tensors were allocated on the side stream. Record their use on the
        # current stream so the caching allocator does not reuse their memory
        # while work queued there is still pending.
        if batch_input is not None:
            self.record_stream(batch_input)
        if batch_target is not None:
            batch_target.record_stream(torch.cuda.current_stream())
        self.preload()
        return batch_input, batch_target

    def _input_cuda_for_image(self):
        # Copy the pending input to the GPU; invoked by preload() inside the
        # side-stream context.
        self.next_input = self.next_input.cuda(non_blocking=True)

    @staticmethod
    def _record_stream_for_image(input):
        # Mark `input` as in use on the current stream (see next()).
        input.record_stream(torch.cuda.current_stream())
def random_resize(data_loader, exp, epoch, rank, is_distributed, *, device="cuda"):
    """Choose a (possibly random) training input size and apply it to ``data_loader``.

    Rank 0 chooses the size. When running distributed, the value is broadcast from
    rank 0 so every process applies the same size.

    Args:
        data_loader: object exposing ``change_input_dim(multiple=..., random_range=...)``.
        exp: experiment config; reads ``max_epoch``, ``input_size`` and ``random_size``.
        epoch: current epoch index.
        rank: process rank; only rank 0 chooses the size.
        is_distributed: whether to synchronize processes and broadcast the size.
        device: device holding the scalar used for the broadcast. Defaults to
            ``"cuda"`` (original behavior). ``"cpu"`` can be used for
            non-distributed runs or a CPU-capable distributed backend.

    Returns:
        Whatever ``data_loader.change_input_dim`` returns.
    """
    # Zero-initialized (the legacy ``torch.LongTensor(1)`` was uninitialized) so a
    # process that never receives a value cannot read garbage memory.
    tensor = torch.zeros(1, dtype=torch.long, device=device)

    if is_distributed:
        synchronize()

    if rank == 0:
        if epoch > exp.max_epoch - 10:
            # Final 10 epochs: use the configured input size unchanged.
            # NOTE(review): assumes exp.input_size is a single int here; a tuple
            # would make fill_ fail. Confirm against the Exp definition.
            size = exp.input_size
        else:
            # random_size gives randint's inclusive (low, high) bounds, scaled by 32.
            size = int(32 * random.randint(*exp.random_size))
        tensor.fill_(size)

    if is_distributed:
        synchronize()
        dist.broadcast(tensor, 0)

    return data_loader.change_input_dim(multiple=tensor.item(), random_range=None)