| import queue as Queue |
| import threading |
| import torch |
| from torch.utils.data import DataLoader |
|
|
|
|
class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    A daemon thread drains ``generator`` into a bounded queue so the
    consumer can overlap its own work with the production of upcoming
    items. The thread is started as soon as the object is constructed.

    Ref:
    https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator (any iterable works).
        num_prefetch_queue (int): Number of prefetch queue, passed as
            ``maxsize`` to ``queue.Queue``.
    """

    # Unique end-of-stream marker. A dedicated object (instead of ``None``)
    # lets the wrapped generator yield ``None`` as a regular item.
    _END_OF_STREAM = object()

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True
        # Exception raised by the generator, handed over to the consumer.
        self._prefetch_error = None
        # Set once the end marker has been consumed, so that further
        # ``next()`` calls do not block on an empty queue.
        self._prefetch_done = False
        self.start()

    def run(self):
        """Worker body: enqueue every item, then the end marker."""
        try:
            for item in self.generator:
                self.queue.put(item)
        except BaseException as exc:
            # Keep the error so it can be re-raised in the consumer thread
            # rather than leaving the consumer blocked forever.
            self._prefetch_error = exc
        finally:
            self.queue.put(self._END_OF_STREAM)

    def __next__(self):
        """Return the next prefetched item.

        Raises:
            StopIteration: When the generator is exhausted (also on every
                subsequent call).
            BaseException: Whatever the wrapped generator raised, re-raised
                once in the consumer thread.
        """
        if self._prefetch_done:
            raise StopIteration
        next_item = self.queue.get()
        if next_item is self._END_OF_STREAM:
            self._prefetch_done = True
            error, self._prefetch_error = self._prefetch_error, None
            if error is not None:
                raise error
            raise StopIteration
        return next_item

    def __iter__(self):
        return self
|
|
|
|
class PrefetchDataLoader(DataLoader):
    """DataLoader whose iterator is served through a prefetch thread.

    Ref:
    https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
    Still needs testing on single gpu and ddp (multi-gpu); the original
    author notes a known issue in ddp.

    Args:
        num_prefetch_queue (int): Number of prefetch queue, forwarded to
            ``PrefetchGenerator`` on every ``__iter__`` call.
        kwargs (dict): Other arguments, passed unchanged to ``DataLoader``.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        self.num_prefetch_queue = num_prefetch_queue
        super().__init__(**kwargs)

    def __iter__(self):
        """Wrap the regular DataLoader iterator in a ``PrefetchGenerator``."""
        base_iterator = super().__iter__()
        return PrefetchGenerator(base_iterator, self.num_prefetch_queue)
|
|
|
|
class CPUPrefetcher:
    """CPU prefetcher.

    Thin wrapper that exposes an iterable through ``next()`` / ``reset()``
    and signals exhaustion with ``None`` instead of ``StopIteration``.

    Args:
        loader: Dataloader (any iterable accepted by ``iter``).
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        """Return the next batch, or ``None`` once the loader is exhausted."""
        return next(self.loader, None)

    def reset(self):
        """Start over with a fresh iterator on the original loader."""
        self.loader = iter(self.ori_loader)
|
|
|
|
class CUDAPrefetcher:
    """CUDA prefetcher.

    Keeps one batch "in flight": while the caller works on the batch just
    returned by ``next()``, the following batch is already being copied to
    the target device on a separate CUDA stream.

    Ref:
    https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Each batch is expected to be a dict; tensor values are moved to the
    device, every other value is left untouched.

    Args:
        loader: Dataloader.
        opt (dict): Options. Only ``opt['num_gpu']`` is read here: ``0``
            selects the CPU, anything else selects CUDA.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        # A side stream only makes sense (and can only be created) when the
        # target device is CUDA; in CPU mode no CUDA API is touched at all.
        if self.device.type == 'cuda':
            self.stream = torch.cuda.Stream()
        else:
            self.stream = None
        self.preload()

    def _to_device(self):
        """Move every tensor value of the pending batch to ``self.device``."""
        for k, v in self.batch.items():
            if torch.is_tensor(v):
                self.batch[k] = v.to(device=self.device, non_blocking=True)

    def preload(self):
        """Fetch the next batch and start moving it to the device.

        Sets ``self.batch`` to ``None`` when the loader is exhausted.
        """
        try:
            self.batch = next(self.loader)
        except StopIteration:
            self.batch = None
            return None

        if self.stream is None:
            self._to_device()
        else:
            # Issue the copies on the side stream so they can overlap with
            # work queued on the caller's current stream.
            with torch.cuda.stream(self.stream):
                self._to_device()

    def next(self):
        """Return the preloaded batch (or ``None``) and preload the next."""
        batch = self.batch
        if self.stream is not None:
            current = torch.cuda.current_stream()
            # Make the consumer stream wait for the pending copies.
            current.wait_stream(self.stream)
            if batch is not None:
                # The tensors were allocated on the side stream; tell the
                # caching allocator they are now used on the current stream
                # so their memory is not recycled too early.
                for v in batch.values():
                    if torch.is_tensor(v) and v.is_cuda:
                        v.record_stream(current)
        self.preload()
        return batch

    def reset(self):
        """Restart iteration on the original loader and preload a batch."""
        self.loader = iter(self.ori_loader)
        self.preload()
|
|