Spaces:
Running
on
Zero
Running
on
Zero
| import queue as Queue | |
| import threading | |
| import torch | |
| from torch.utils.data import DataLoader | |
class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    Consumes ``generator`` in a daemon background thread and buffers its
    items in a bounded queue; iterating this object pops items from that
    queue in the order they were produced.

    Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    # Unique end-of-stream marker. A dedicated object (instead of ``None``)
    # means a legitimate ``None`` item does not end iteration early.
    _END = object()

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True
        # Exception raised by the producer thread, re-raised in the consumer.
        self._error = None
        # Set once the end marker has been consumed, so later ``next()``
        # calls raise StopIteration instead of blocking on an empty queue.
        self._exhausted = False
        self.start()

    def run(self):
        """Producer loop: push every item, then always push the end marker."""
        try:
            for item in self.generator:
                self.queue.put(item)
        except Exception as err:
            # Keep the failure so the consumer can surface it.
            self._error = err
        finally:
            # Without this the consumer would wait forever when the
            # generator fails part-way through.
            self.queue.put(self._END)

    def __next__(self):
        """Return the next buffered item.

        Raises:
            StopIteration: When the generator is exhausted (also on every
                subsequent call).
            Exception: Whatever exception the generator raised in the
                producer thread, re-raised once in the calling thread.
        """
        if self._exhausted:
            raise StopIteration
        next_item = self.queue.get()
        if next_item is self._END:
            self._exhausted = True
            if self._error is not None:
                err, self._error = self._error, None
                raise err
            raise StopIteration
        return next_item

    def __iter__(self):
        return self
class PrefetchDataLoader(DataLoader):
    """Dataloader whose iterator is served through a ``PrefetchGenerator``.

    Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
        Need to test on single gpu and ddp (multi-gpu). There is a known
        issue in ddp.

    Args:
        num_prefetch_queue (int): Number of prefetch queue.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        # Stored before the parent constructor runs, as in the original.
        self.num_prefetch_queue = num_prefetch_queue
        super().__init__(**kwargs)

    def __iter__(self):
        # Wrap the regular DataLoader iterator in the prefetching wrapper.
        base_iterator = super().__iter__()
        return PrefetchGenerator(base_iterator, self.num_prefetch_queue)
class CPUPrefetcher:
    """CPU prefetcher.

    Thin wrapper around an iterable that exposes ``next()`` / ``reset()``.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(self.ori_loader)

    def next(self):
        """Return the next item, or ``None`` once the iterator is exhausted."""
        return next(self.loader, None)

    def reset(self):
        """Start over with a fresh iterator on the original loader."""
        self.loader = iter(self.ori_loader)
class CUDAPrefetcher:
    """CUDA prefetcher.

    Keeps one batch loaded ahead of the consumer. On a CUDA device the
    host-to-device copies are issued on a separate stream.

    Reference: https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Args:
        loader: Dataloader. Each batch is expected to be a dict; tensor
            values are moved to the target device, other values are left
            untouched.
        opt (dict): Options. ``opt['num_gpu']`` selects the device:
            ``0`` means CPU, anything else means CUDA.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        # A side stream only makes sense (and can only be created) on CUDA;
        # in CPU mode ``stream`` is None and all stream handling is skipped.
        self.stream = torch.cuda.Stream() if self.device.type == 'cuda' else None
        self.preload()

    def _move_batch_to_device(self):
        """Move every tensor value of ``self.batch`` to ``self.device``."""
        for k, v in self.batch.items():
            if torch.is_tensor(v):
                self.batch[k] = v.to(device=self.device, non_blocking=True)

    def preload(self):
        """Fetch the next batch into ``self.batch`` (``None`` when exhausted)."""
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        # put tensors to gpu
        if self.stream is None:
            self._move_batch_to_device()
        else:
            with torch.cuda.stream(self.stream):
                self._move_batch_to_device()

    def next(self):
        """Return the preloaded batch (or ``None``) and start loading the next."""
        if self.stream is not None:
            current = torch.cuda.current_stream()
            # Make the consumer stream wait for the pending copies.
            current.wait_stream(self.stream)
            if self.batch is not None:
                # The tensors were allocated on the side stream; tell the
                # caching allocator they are also used on the current stream
                # so their memory is not reused while still being read.
                for v in self.batch.values():
                    if torch.is_tensor(v):
                        v.record_stream(current)
        batch = self.batch
        self.preload()
        return batch

    def reset(self):
        """Restart iteration on the original loader and preload a batch."""
        self.loader = iter(self.ori_loader)
        self.preload()