import time

from typing import Any, List

import torch.utils.data.backward_compatibility

import torch.utils.data.graph_settings
from torch.utils.data import DataLoader, IterDataPipe, communication
from torch.utils.data.datapipes.iter import IterableWrapper

__all__ = [
    "DataLoader2",
]


class _ThreadingDataLoader2:
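    """Threading-based prototype loader for ``IterDataPipe`` pipelines.

    Each worker runs its own copy of the pipeline on a dedicated thread and
    exchanges data with the main thread through request/response queues
    (see ``torch.utils.data.communication``).
    """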

    def __init__(self, datapipe, num_workers=0, collate_fn=None):
        self.threads = []
        self.datapipes = []
        self.collate_fn = collate_fn
        for worker_id in range(num_workers):
            # Spawn an event-loop thread with its own copy of the pipeline and
            # shard that copy so each worker sees a disjoint slice of the data.
            (thread, req_queue, res_queue, thread_localdatapipe) = communication.eventloop.SpawnThreadForDataPipeline(datapipe)
            torch.utils.data.graph_settings.apply_sharding(thread_localdatapipe, num_workers, worker_id)
            thread.start()
            self.threads.append((thread, req_queue, res_queue))
            # Client-side wrapper that talks to the worker thread over its queues.
            local_datapipe = communication.iter.QueueWrapper(
                communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue))
            self.datapipes.append(local_datapipe)

    def __iter__(self):
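        # Round-robin over the per-worker pipes: yield whatever is ready,
        # retire a pipe once it raises StopIteration, and back off briefly
        # when nothing was available on a full pass.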
        exclude_datapipes: List[Any] = []
        while len(exclude_datapipes) < len(self.datapipes):
            not_available = False
            for dp in self.datapipes:
                if dp not in exclude_datapipes:
                    try:
                        value = dp.nonblocking_next()
                        yield value
                    except StopIteration:
                        exclude_datapipes.append(dp)
                    except communication.iter.NotAvailable:
                        not_available = True
            if not_available:
                time.sleep(0.001)

    def __del__(self):
        self._cleanup_all_threads()

    def _cleanup_all_threads(self):
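        # Ask each worker's event loop to terminate, wait for the
        # acknowledgement, and join the thread.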
        def clean_me(thread, req_queue, res_queue):
            req_queue.put(communication.messages.TerminateRequest())
            _ = res_queue.get()
            thread.join()

        for thread, req_queue, res_queue in self.threads:
            clean_me(thread, req_queue, res_queue)


class DataLoader2:
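    """Prototype ``DataLoader`` front end with DataPipe support.

    ``IterDataPipe`` inputs are served either by a multiprocessing-backed
    ``DataLoader`` (``parallelism_mode='mp'``) or by the threading prototype
    above (``parallelism_mode='thread'``); any other dataset falls back to
    the classic ``DataLoader``.
    """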
    def __new__(cls,
                dataset,
                batch_size=1,
                shuffle=None,
                sampler=None,
                batch_sampler=None,
                num_workers=0,
                collate_fn=None,
                pin_memory=False,
                drop_last=False,
                timeout=0,
                worker_init_fn=None,
                *,
                prefetch_factor=2,
                persistent_workers=False,
                batch_outside_worker=False,
                parallelism_mode='mp'):
        if isinstance(dataset, IterDataPipe):
            data_loader: Any = None
            if batch_sampler is not None:
                raise Exception(
                    'batch_sampler is not yet supported by DataPipes')
            if sampler is not None:
                raise Exception(
                    'sampler is not yet supported by DataPipes')
            datapipe = dataset
            datapipe = torch.utils.data.graph_settings.apply_shuffle_settings(datapipe, shuffle=shuffle)
            if batch_outside_worker and pin_memory:
                raise Exception(
                    'pin_memory is not yet compatible with batch_outside_worker')
            # The default 'mp' mode (or num_workers == 0) goes through the regular DataLoader.
            if parallelism_mode == 'mp' or num_workers == 0:
                if not batch_outside_worker:
                    if batch_size is not None:
                        datapipe = datapipe.batch(batch_size, drop_last=drop_last)
                        if collate_fn is None:
                            collate_fn = torch.utils.data._utils.collate.default_collate

                data_loader = DataLoader(datapipe,
                                         batch_size=None,  # batching is done by the .batch DataPipe or outside the workers
                                         shuffle=shuffle,
                                         sampler=None,
                                         batch_sampler=None,
                                         num_workers=num_workers,
                                         collate_fn=collate_fn,
                                         pin_memory=pin_memory,
                                         drop_last=False,  # likewise handled where batching happens
                                         timeout=timeout,
                                         worker_init_fn=worker_init_fn,
                                         prefetch_factor=prefetch_factor,
                                         persistent_workers=persistent_workers)
            elif parallelism_mode == 'thread':
                if collate_fn is not None and not batch_outside_worker:
                    datapipe = datapipe.map(collate_fn)
                if pin_memory:
                    raise Exception(
                        'pin_memory is not yet supported by DataPipes with Threading')
                if worker_init_fn is not None:
                    raise Exception(
                        'worker_init_fn is not yet supported by DataPipes with Threading')
                data_loader = _ThreadingDataLoader2(datapipe,
                                                    num_workers=num_workers,
                                                    collate_fn=collate_fn)
            else:
                raise Exception('Unsupported parallelism mode', parallelism_mode)
            if not batch_outside_worker:
                return data_loader
            else:
                # Batch and collate in the main process, on top of the workers' output.
                if collate_fn is None:
                    collate_fn = torch.utils.data._utils.collate.default_collate
                datapipe = IterableWrapper(data_loader).batch(
                    batch_size, drop_last=drop_last).map(collate_fn)
                return datapipe
        else:
            if parallelism_mode == 'thread':
                raise Exception(
                    'thread parallelism mode is not supported for old DataSets')
            return DataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=shuffle,
                              sampler=sampler,
                              batch_sampler=batch_sampler,
                              num_workers=num_workers,
                              collate_fn=collate_fn,
                              pin_memory=pin_memory,
                              drop_last=drop_last,
                              timeout=timeout,
                              worker_init_fn=worker_init_fn,
                              prefetch_factor=prefetch_factor,
                              persistent_workers=persistent_workers)
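
# Example usage (a sketch, not part of the module API): wrapping an IterDataPipe
# with DataLoader2, assuming the prototype communication/eventloop backends are
# available in this build.
#
#   from torch.utils.data.datapipes.iter import IterableWrapper
#   dp = IterableWrapper(range(10)).shuffle()
#   dl = DataLoader2(dp, batch_size=2, num_workers=2, parallelism_mode='thread')
#   for batch in dl:
#       ...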