conex / espnet /utils /training /iterators.py
tobiasc's picture
Initial commit
ad16788
raw
history blame
No virus
3.59 kB
import chainer
from chainer.iterators import MultiprocessIterator
from chainer.iterators import SerialIterator
from chainer.iterators import ShuffleOrderSampler
from chainer.training.extension import Extension
import numpy as np
class ShufflingEnabler(Extension):
    """Trainer extension that switches on shuffling for a set of iterators."""

    def __init__(self, iterators):
        """Remember the iterators whose shuffling will be enabled later.

        :param list[Iterator] iterators: The iterators to enable shuffling on;
            each one must provide a ``start_shuffle()`` method
        """
        self.iterators = iterators
        # Flips to True once start_shuffle() has been issued to every iterator.
        self.set = False

    def __call__(self, trainer):
        """Enable shuffling on the stored iterators, the first time only.

        :param trainer: The trainer invoking this extension (not used here)
        """
        if self.set:
            # Shuffling was already turned on by an earlier invocation.
            return
        for it in self.iterators:
            it.start_shuffle()
        self.set = True
class ToggleableShufflingSerialIterator(SerialIterator):
    """A SerialIterator having its shuffling property activated during training"""

    def __init__(self, dataset, batch_size, repeat=True, shuffle=True):
        """Init the Iterator

        :param dataset: The dataset to take batches from
        :param int batch_size: The batch size
        :param bool repeat: Whether to repeat data (allow multiple epochs)
        :param bool shuffle: Whether to shuffle the batches
        """
        super(ToggleableShufflingSerialIterator, self).__init__(
            dataset, batch_size, repeat, shuffle
        )

    def start_shuffle(self):
        """Starts shuffling (or reshuffles) the batches

        Sets the shuffle flag and replaces the current sample order with a
        freshly drawn random one. The mechanism depends on the installed
        chainer major version: up to v4 the order is a plain random
        permutation, later versions go through a ``ShuffleOrderSampler``.
        """
        self._shuffle = True
        # Parse the complete major component: looking only at the first
        # character of the version string would misread e.g. "10.0.0" as 1.
        major_version = int(chainer.__version__.split(".")[0])
        n_samples = len(self.dataset)
        if major_version <= 4:
            self._order = np.random.permutation(n_samples)
        else:
            self.order_sampler = ShuffleOrderSampler()
            # The sampler is called with the current order and position 0.
            self._order = self.order_sampler(np.arange(n_samples), 0)
class ToggleableShufflingMultiprocessIterator(MultiprocessIterator):
    """A MultiprocessIterator having its shuffling property activated during training"""

    def __init__(
        self,
        dataset,
        batch_size,
        repeat=True,
        shuffle=True,
        n_processes=None,
        n_prefetch=1,
        shared_mem=None,
        maxtasksperchild=20,
    ):
        """Init the iterator

        All arguments are forwarded unchanged, by keyword, to
        ``MultiprocessIterator.__init__``.

        :param dataset: The dataset to take batches from
        :param int batch_size: The batch size
        :param bool repeat: Whether to repeat batches or not (enables multiple epochs)
        :param bool shuffle: Whether to shuffle the order of the batches
        :param int n_processes: How many processes to use
        :param int n_prefetch: The number of batches to prefetch
        :param int shared_mem: How much memory to share between processes
        :param int maxtasksperchild: Maximum number of tasks per child
        """
        super(ToggleableShufflingMultiprocessIterator, self).__init__(
            dataset=dataset,
            batch_size=batch_size,
            repeat=repeat,
            shuffle=shuffle,
            n_processes=n_processes,
            n_prefetch=n_prefetch,
            shared_mem=shared_mem,
            maxtasksperchild=maxtasksperchild,
        )

    def start_shuffle(self):
        """Starts shuffling (or reshuffles) the batches

        Sets the shuffle flag, replaces the current sample order with a
        freshly drawn random one and then refreshes the prefetch state. The
        mechanism depends on the installed chainer major version: up to v4 the
        order is a plain random permutation, later versions go through a
        ``ShuffleOrderSampler``.
        """
        self.shuffle = True
        # Parse the complete major component: looking only at the first
        # character of the version string would misread e.g. "10.0.0" as 1.
        major_version = int(chainer.__version__.split(".")[0])
        n_samples = len(self.dataset)
        if major_version <= 4:
            self._order = np.random.permutation(n_samples)
        else:
            self.order_sampler = ShuffleOrderSampler()
            # The sampler is called with the current order and position 0.
            self._order = self.order_sampler(np.arange(n_samples), 0)
        # Refresh the parent's prefetch state after changing the order
        # (presumably so the prefetching side picks up the new order).
        self._set_prefetch_state()