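# NOTE(review): this file originally began with the lines "Spaces:", "Sleeping",
# "Sleeping" -- apparently page-status text captured along with the source, not
# Python. Kept here only as a comment so the module can be parsed.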
import collections
import re
import signal
import sys
import threading
import traceback

import numpy as np
import torch
import torch.multiprocessing as multiprocessing
from torch._C import _set_worker_signal_handlers, \
    _remove_worker_pids, _error_if_any_worker_fails
try:
    from torch._C import _set_worker_pids
except ImportError:
    from torch._C import _update_worker_pids as _set_worker_pids
from torch._six import string_classes, int_classes

from .sampler import SequentialSampler, RandomSampler, BatchSampler

try:
    from collections.abc import Mapping, Sequence
except ImportError:  # Python 2: the ABCs live directly in `collections`
    from collections import Mapping, Sequence

if sys.version_info[0] == 2:
    import Queue as queue
else:
    import queue
class ExceptionWrapper(object):
    r"""Wraps an exception plus traceback to communicate across threads.

    Only the exception type and the fully formatted traceback text are kept,
    so the receiving side can re-raise ``exc_type(exc_msg)``.

    Arguments:
        exc_info (tuple): the ``(type, value, traceback)`` triple returned by
            :func:`sys.exc_info`.
    """

    def __init__(self, exc_info):
        formatted_lines = traceback.format_exception(*exc_info)
        self.exc_msg = "".join(formatted_lines)
        self.exc_type = exc_info[0]
# Whether default_collate should stack tensors directly into shared memory.
# False in the main process; _worker_loop sets it to True in each worker.
_use_shared_memory = False
def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
    r"""Main function of a DataLoader worker process.

    Seeds torch and NumPy with ``seed``, calls ``init_fn(worker_id)`` if
    given, then serves ``(idx, batch_indices)`` requests from ``index_queue``
    until a ``None`` sentinel arrives. Each request yields exactly one
    ``(idx, result)`` message on ``data_queue``. ``result`` is either the
    collated batch or an :class:`ExceptionWrapper` for the error raised
    while building it.
    """
    global _use_shared_memory
    # Make default_collate stack batches straight into shared memory.
    _use_shared_memory = True

    # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
    # module's handlers are executed after Python returns from C low-level
    # handlers, likely when the same fatal signal happened again already.
    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    torch.manual_seed(seed)
    np.random.seed(seed)

    if init_fn is not None:
        init_fn(worker_id)

    # Keep pulling requests until the None shutdown sentinel is received.
    for request in iter(index_queue.get, None):
        batch_idx, sample_indices = request
        try:
            outcome = collate_fn([dataset[i] for i in sample_indices])
        except Exception:
            outcome = ExceptionWrapper(sys.exc_info())
        data_queue.put((batch_idx, outcome))
def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
    r"""Relay ``(idx, batch)`` messages from ``in_queue`` to ``out_queue``.

    Worker-side failures (:class:`ExceptionWrapper` payloads) are forwarded
    untouched. When ``pin_memory`` is set, the current CUDA device is set to
    ``device_id`` and each batch is pinned before forwarding. A failure while
    pinning is forwarded as an :class:`ExceptionWrapper`.

    The loop returns on a ``None`` sentinel. It also returns silently if
    reading ``in_queue`` raises after ``done_event`` has been set. A read
    error before shutdown is re-raised.
    """
    if pin_memory:
        torch.cuda.set_device(device_id)

    while True:
        try:
            message = in_queue.get()
        except Exception:
            if not done_event.is_set():
                raise
            return

        if message is None:
            return

        if isinstance(message[1], ExceptionWrapper):
            out_queue.put(message)
            continue

        idx, batch = message
        if pin_memory:
            try:
                batch = pin_memory_batch(batch)
            except Exception:
                batch = ExceptionWrapper(sys.exc_info())
        out_queue.put((idx, batch))
# NumPy dtype name -> legacy torch tensor type. default_collate uses it to
# turn a batch of NumPy scalars into a 1-D tensor of the matching type.
numpy_type_map = dict(
    float64=torch.DoubleTensor,
    float32=torch.FloatTensor,
    float16=torch.HalfTensor,
    int64=torch.LongTensor,
    int32=torch.IntTensor,
    int16=torch.ShortTensor,
    int8=torch.CharTensor,
    uint8=torch.ByteTensor,
)
def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size.

    ``batch`` is a list of samples sharing the structure of ``batch[0]``. The
    result mirrors that structure:

    * tensors are stacked along a new dimension 0. Inside a worker process
      they are stacked directly into shared memory.
    * NumPy arrays (except string/object dtypes) go through
      :func:`torch.from_numpy` and are stacked.
    * NumPy scalars become a 1-D tensor of the matching type.
    * Python ints become a ``LongTensor``; floats become a ``DoubleTensor``.
    * strings are returned unchanged, as the input list.
    * mappings are collated key by key, using the keys of ``batch[0]``.
    * other sequences are transposed and collated position by position,
      giving a list.

    Raises:
        TypeError: if an element type, or a NumPy string/object dtype, cannot
            be collated.
    """
    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem = batch[0]
    elem_type = type(elem)
    if torch.is_tensor(elem):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))
            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(elem, int_classes):
        return torch.LongTensor(batch)
    elif isinstance(elem, float):
        return torch.DoubleTensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, Mapping):
        # collections.Mapping no longer exists on Python 3.10+; use the ABC.
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, Sequence):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(error_msg.format(elem_type))
def pin_memory_batch(batch):
    r"""Pin every tensor found in ``batch`` (via ``Tensor.pin_memory()``).

    The function recurses into mappings, returning a plain ``dict``, and into
    non-string sequences, returning a ``list``. Strings and any other leaf
    values are returned unchanged.
    """
    if torch.is_tensor(batch):
        return batch.pin_memory()
    elif isinstance(batch, string_classes):
        return batch
    elif isinstance(batch, Mapping):
        # collections.Mapping no longer exists on Python 3.10+; use the ABC.
        return {k: pin_memory_batch(sample) for k, sample in batch.items()}
    elif isinstance(batch, Sequence):
        return [pin_memory_batch(sample) for sample in batch]
    else:
        return batch
# Whether the DataLoader SIGCHLD handler (for detecting worker failures) has
# already been installed. One handler serves every DataLoader in the process,
# so _set_SIGCHLD_handler installs it at most once.
_SIGCHLD_handler_set = False
def _set_SIGCHLD_handler():
    r"""Install, at most once per process, a SIGCHLD handler that checks for
    failed DataLoader workers. The handler then chains to any previously
    installed callable handler.

    Does nothing on Windows, which has no SIGCHLD handler. It also does
    nothing when called from a non-main thread, since signals can't be set
    in child threads.
    """
    global _SIGCHLD_handler_set
    if (sys.platform == 'win32'
            or not isinstance(threading.current_thread(), threading._MainThread)
            or _SIGCHLD_handler_set):
        return

    existing = signal.getsignal(signal.SIGCHLD)
    # Only a callable previous handler can be chained.
    chained_handler = existing if callable(existing) else None

    def handler(signum, frame):
        # This following call uses `waitid` with WNOHANG from C side. Therefore,
        # Python can still get and update the process status successfully.
        _error_if_any_worker_fails()
        if chained_handler is not None:
            chained_handler(signum, frame)

    signal.signal(signal.SIGCHLD, handler)
    _SIGCHLD_handler_set = True
class DataLoaderIter(object):
    r"""Iterates once over the DataLoader's dataset, as specified by the sampler.

    With ``num_workers == 0``, batches are loaded synchronously in the calling
    process.

    Otherwise ``num_workers`` worker processes are started. Batch index lists
    are sent to them through ``index_queue``, with at most
    ``2 * num_workers`` batches outstanding. Results come back tagged with
    their send index; out-of-order results are buffered in ``reorder_dict``,
    so batches are returned in sampler order.

    When pinning memory or using a timeout, a background thread relays the
    worker results into a ``queue.Queue``.
    """

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.index_queue = multiprocessing.SimpleQueue()
            self.worker_result_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            # int() so every worker receives a plain Python int seed, whether
            # indexing the tensor yields a number or a 0-dim tensor.
            base_seed = int(torch.LongTensor(1).random_(0, 2**31 - 1)[0])
            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
                          base_seed + i, self.worker_init_fn, i))
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                # A relay thread feeds a queue.Queue. This lets _get_batch use
                # get(timeout=...) and lets batches be pinned off the main thread.
                self.data_queue = queue.Queue()
                if self.pin_memory:
                    maybe_device_id = torch.cuda.current_device()
                else:
                    # do not initialize cuda context if not necessary
                    maybe_device_id = None
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                          maybe_device_id))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            _set_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        return len(self.batch_sampler)

    def _get_batch(self):
        # Block for the next (idx, batch) message. With a positive timeout, an
        # empty wait becomes a RuntimeError.
        if self.timeout > 0:
            try:
                return self.data_queue.get(timeout=self.timeout)
            except queue.Empty:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        else:
            return self.data_queue.get()

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self

    def _put_indices(self):
        # Send the sampler's next batch of indices to the workers, if any remain.
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queue.put((self.send_idx, indices))
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        # Advance the expected index, refill the pipeline, and re-raise any
        # error that a worker reported for this batch.
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("DataLoaderIterator cannot be pickled")

    def _shutdown_workers(self):
        # Runs at most once (guarded by self.shutdown). It signals the relay
        # thread, drains the data queue, and sends one None sentinel per
        # worker. The worker pids are always unregistered.
        try:
            if not self.shutdown:
                self.shutdown = True
                self.done_event.set()
                # if worker_manager_thread is waiting to put
                while not self.data_queue.empty():
                    self.data_queue.get()
                for _ in self.workers:
                    self.index_queue.put(None)
                # done_event should be sufficient to exit worker_manager_thread,
                # but be safe here and put another None
                self.worker_result_queue.put(None)
        finally:
            # removes pids no matter what
            if self.worker_pids_set:
                _remove_worker_pids(id(self))
                self.worker_pids_set = False

    def __del__(self):
        # __init__ may have failed before the worker list existed, e.g. in
        # iter(batch_sampler) or while creating the queues. In that case there
        # is nothing to shut down, and _shutdown_workers would hit missing
        # attributes.
        if getattr(self, 'num_workers', 0) > 0 and hasattr(self, 'workers'):
            self._shutdown_workers()
class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch (default: :func:`default_collate`).
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them. It has no effect when
            CUDA is not available. (default: False)
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: False)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: 0)
        worker_init_fn (callable, optional): If not None, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: None)

    Raises:
        ValueError: if ``timeout`` or ``num_workers`` is negative, if ``sampler``
            is given together with ``shuffle``, or if ``batch_sampler`` is given
            together with ``batch_size > 1``, ``shuffle``, ``sampler`` or
            ``drop_last``.

    .. note:: By default, each worker will have both its PyTorch seed and its
              NumPy seed set to ``base_seed + worker_id``, where ``base_seed``
              is a long generated by main process using its RNG. You may use
              ``torch.initial_seed()`` to access this value in
              :attr:`worker_init_fn`, which can be used to set other seeds
              (e.g. Python's :mod:`random`) before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                 unpicklable object, e.g., a lambda function.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        # Without an explicit batch_sampler, build one from the (possibly
        # default) sampler plus batch_size / drop_last.
        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        return DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)