import torch
import torch.multiprocessing as multiprocessing
from torch._C import (_error_if_any_worker_fails, _remove_worker_pids,
                      _set_worker_signal_handlers)
try:
    from torch._C import _set_worker_pids
except ImportError:
    # Fall back to the alternative name exported by torch._C.
    from torch._C import _update_worker_pids as _set_worker_pids
import collections
import re
import signal
import sys
import threading
import traceback
import numpy as np
from torch._six import int_classes, string_classes
from .sampler import BatchSampler, RandomSampler, SequentialSampler

try:
    # The container ABCs live in ``collections.abc`` on Python 3; the aliases
    # in ``collections`` itself were removed in Python 3.10.
    import collections.abc as container_abcs
except ImportError:
    # Python 2 only exposes them directly on ``collections``.
    container_abcs = collections

if sys.version_info[0] == 2:
    import Queue as queue
else:
    import queue


class ExceptionWrapper(object):
    r"Wraps an exception plus traceback to communicate across threads"

    def __init__(self, exc_info):
        """Store the exception type and formatted traceback from ``exc_info``.

        Arguments:
            exc_info (tuple): a ``(type, value, traceback)`` triple as
                returned by ``sys.exc_info()``.
        """
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))


_use_shared_memory = False
"""Whether to use shared memory in default_collate"""


def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
    """Body of a data-loading worker process.

    Seeds the torch and NumPy RNGs with ``seed``, calls
    ``init_fn(worker_id)`` if given, then repeatedly takes
    ``(idx, batch_indices)`` items from ``index_queue``, collates the
    corresponding samples and puts ``(idx, samples)`` on ``data_queue``.
    An exception raised while loading or collating is sent as
    ``(idx, ExceptionWrapper)`` instead. The loop ends when ``None`` is
    received from ``index_queue``.
    """
    global _use_shared_memory
    _use_shared_memory = True

    # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
    # module's handlers are executed after Python returns from C low-level
    # handlers, likely when the same fatal signal happened again already.
    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    torch.manual_seed(seed)
    np.random.seed(seed)

    if init_fn is not None:
        init_fn(worker_id)

    while True:
        r = index_queue.get()
        if r is None:
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))


def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
    """Forward ``(idx, batch)`` items from ``in_queue`` to ``out_queue``.

    When ``pin_memory`` is true the CUDA device is set to ``device_id`` and
    every batch is passed through ``pin_memory_batch`` before forwarding;
    a failure while pinning is forwarded as ``(idx, ExceptionWrapper)``.
    Items that already carry an ``ExceptionWrapper`` are forwarded untouched.
    The loop ends on a ``None`` item, or when reading ``in_queue`` fails
    after ``done_event`` has been set.
    """
    if pin_memory:
        torch.cuda.set_device(device_id)

    while True:
        try:
            r = in_queue.get()
        except Exception:
            # A failing read is only expected once shutdown has been signalled.
            if done_event.is_set():
                return
            raise
        if r is None:
            break
        if isinstance(r[1], ExceptionWrapper):
            out_queue.put(r)
            continue
        idx, batch = r
        try:
            if pin_memory:
                batch = pin_memory_batch(batch)
        except Exception:
            out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            out_queue.put((idx, batch))


# Maps a NumPy dtype name to the tensor constructor used for NumPy scalars.
numpy_type_map = {
    'float64': torch.DoubleTensor,
    'float32': torch.FloatTensor,
    'float16': torch.HalfTensor,
    'int64': torch.LongTensor,
    'int32': torch.IntTensor,
    'int16': torch.ShortTensor,
    'int8': torch.CharTensor,
    'uint8': torch.ByteTensor,
}


def default_collate(batch):
    "Puts each data field into a tensor with outer dimension batch size"

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if torch.is_tensor(batch[0]):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))

            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], int_classes):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], container_abcs.Mapping):
        # Collate each key separately; keys are taken from the first element.
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], container_abcs.Sequence):
        # Regroup the batch by field position, then collate every field.
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError((error_msg.format(type(batch[0]))))


def pin_memory_batch(batch):
    """Recursively call ``pin_memory()`` on every tensor in ``batch``.

    Strings are returned unchanged, mappings are rebuilt as ``dict`` and
    other sequences as ``list``; any other object is returned as is.
    """
    if torch.is_tensor(batch):
        return batch.pin_memory()
    elif isinstance(batch, string_classes):
        return batch
    elif isinstance(batch, container_abcs.Mapping):
        return {k: pin_memory_batch(sample) for k, sample in batch.items()}
    elif isinstance(batch, container_abcs.Sequence):
        return [pin_memory_batch(sample) for sample in batch]
    else:
        return batch


_SIGCHLD_handler_set = False
"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
handler needs to be set for all DataLoaders in a process."""


def _set_SIGCHLD_handler():
    """Install, at most once, a SIGCHLD handler that checks worker failures.

    Does nothing on Windows, outside the main thread, or when the handler
    has already been installed. A previously registered callable handler is
    chained and still invoked after the worker check.
    """
    # Windows doesn't support SIGCHLD handler
    if sys.platform == 'win32':
        return
    # can't set signal in child threads
    if not isinstance(threading.current_thread(), threading._MainThread):
        return
    global _SIGCHLD_handler_set
    if _SIGCHLD_handler_set:
        return
    previous_handler = signal.getsignal(signal.SIGCHLD)
    if not callable(previous_handler):
        # SIG_DFL / SIG_IGN / None are not callable and cannot be chained.
        previous_handler = None

    def handler(signum, frame):
        # This following call uses `waitid` with WNOHANG from C side. Therefore,
        # Python can still get and update the process status successfully.
        _error_if_any_worker_fails()
        if previous_handler is not None:
            previous_handler(signum, frame)

    signal.signal(signal.SIGCHLD, handler)
    _SIGCHLD_handler_set = True


class DataLoaderIter(object):
    "Iterates once over the DataLoader's dataset, as specified by the sampler"

    def __init__(self, loader):
        """Copy the configuration from ``loader`` and start workers.

        With ``num_workers > 0`` this creates the index/result queues,
        starts the worker processes (and, when pinning memory or using a
        timeout, a manager thread), registers the worker pids and queues
        the first ``2 * num_workers`` batches of indices.
        """
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.index_queue = multiprocessing.SimpleQueue()
            self.worker_result_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            # Indexing may yield a 0-dim tensor instead of a Python number
            # depending on the torch version; coerce to a plain int so that
            # every worker receives an ordinary integer seed.
            base_seed = int(torch.LongTensor(1).random_(0, 2**31-1)[0])
            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
                          base_seed + i, self.worker_init_fn, i))
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                self.data_queue = queue.Queue()
                if self.pin_memory:
                    maybe_device_id = torch.cuda.current_device()
                else:
                    # do not initialize cuda context if not necessary
                    maybe_device_id = None
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                          maybe_device_id))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            _set_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        """Return the number of batches, i.e. ``len(batch_sampler)``."""
        return len(self.batch_sampler)

    def _get_batch(self):
        """Fetch the next ``(idx, batch)`` item from the data queue.

        Raises:
            RuntimeError: if ``timeout`` is positive and no item arrives
                within that many seconds.
        """
        if self.timeout > 0:
            try:
                return self.data_queue.get(timeout=self.timeout)
            except queue.Empty:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        else:
            return self.data_queue.get()

    def __next__(self):
        """Return the next batch, in sampler order.

        Raises:
            StopIteration: when the sampler is exhausted and no batches
                are outstanding.
        """
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        """Return ``self``; the object is its own iterator."""
        return self

    def _put_indices(self):
        """Send the next batch of indices to the workers, if any remain."""
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queue.put((self.send_idx, indices))
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        """Advance the receive index, queue more work and return ``batch``.

        If ``batch`` is an ``ExceptionWrapper``, an exception of the wrapped
        type is raised with the formatted traceback as its message.
        """
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("DataLoaderIterator cannot be pickled")

    def _shutdown_workers(self):
        """Signal the workers and manager thread to exit; unregister pids.

        The shutdown signalling runs only on the first call; the pid
        removal is attempted on every call and runs in ``finally``.
        """
        try:
            if not self.shutdown:
                self.shutdown = True
                self.done_event.set()
                # if worker_manager_thread is waiting to put
                while not self.data_queue.empty():
                    self.data_queue.get()
                for _ in self.workers:
                    self.index_queue.put(None)
                # done_event should be sufficient to exit worker_manager_thread,
                # but be safe here and put another None
                self.worker_result_queue.put(None)
        finally:
            # removes pids no matter what
            if self.worker_pids_set:
                _remove_worker_pids(id(self))
                self.worker_pids_set = False

    def __del__(self):
        """Shut the workers down when the iterator is garbage collected."""
        if self.num_workers > 0:
            self._shutdown_workers()


class DataLoader(object):
    """
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: False)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: 0)
        worker_init_fn (callable, optional): If not None, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: None)

    .. note:: By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by main process using its RNG. NumPy's global RNG in each worker
              is seeded with the same value. You may use
              ``torch.initial_seed()`` to access this value in
              :attr:`worker_init_fn`, which can be used to set other seeds
              (e.g. NumPy) before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                 unpicklable object, e.g., a lambda function.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        """Validate the options and build the sampler / batch sampler.

        Raises:
            ValueError: if ``timeout`` or ``num_workers`` is negative, if
                ``batch_sampler`` is combined with ``batch_size``,
                ``shuffle``, ``sampler`` or ``drop_last``, or if ``sampler``
                is combined with ``shuffle``.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        """Return a fresh ``DataLoaderIter`` over this loader."""
        return DataLoaderIter(self)

    def __len__(self):
        """Return the number of batches, i.e. ``len(batch_sampler)``."""
        return len(self.batch_sampler)