import time

from typing import Any, List

import torch.utils.data.backward_compatibility
import torch.utils.data.graph_settings
from torch.utils.data import DataLoader, IterDataPipe, communication
from torch.utils.data.datapipes.iter import IterableWrapper

__all__ = [
    "DataLoader2",
]


class _ThreadingDataLoader2:
    """Experimental loader that runs one sharded copy of the DataPipe per worker
    thread and round-robins over the per-thread result queues."""

    def __init__(self, datapipe, num_workers=0, collate_fn=None):
        self.threads = []
        self.datapipes = []
        self.collate_fn = collate_fn
        for worker_id in range(num_workers):
            (thread, req_queue, res_queue, thread_localdatapipe) = communication.eventloop.SpawnThreadForDataPipeline(datapipe)
            torch.utils.data.graph_settings.apply_sharding(thread_localdatapipe, num_workers, worker_id)
            thread.start()
            self.threads.append((thread, req_queue, res_queue))  # These queues are independent
            local_datapipe = communication.iter.QueueWrapper(
                communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue))
            self.datapipes.append(local_datapipe)

    def __iter__(self):
        # Round-robin over the per-thread DataPipes until all of them are
        # exhausted; back off briefly whenever no thread had data ready.
        not_available = False
        forever = True
        exclude_datapipes: List[Any] = []
        while len(exclude_datapipes) < len(self.datapipes):
            for dp in self.datapipes:
                if dp not in exclude_datapipes:
                    try:
                        value = dp.nonblocking_next()
                        yield value
                    except StopIteration:
                        exclude_datapipes.append(dp)
                    except communication.iter.NotAvailable:
                        not_available = True
            if not_available:
                time.sleep(0.001)

    def __del__(self):
        self._cleanup_all_threads()

    def _cleanup_all_threads(self):
        def clean_me(thread, req_queue, res_queue):
            req_queue.put(communication.messages.TerminateRequest())
            _ = res_queue.get()
            thread.join()

        for thread, req_queue, res_queue in self.threads:
            clean_me(thread, req_queue, res_queue)


class DataLoader2:
    # Returns either a classic DataLoader, a _ThreadingDataLoader2, or a batching
    # DataPipe wrapper, depending on the dataset type and parallelism_mode.
    def __new__(cls,
                dataset,
                batch_size=1,
                shuffle=None,
                sampler=None,
                batch_sampler=None,
                num_workers=0,
                collate_fn=None,
                pin_memory=False,
                drop_last=False,
                timeout=0,
                worker_init_fn=None,
                *,
                prefetch_factor=2,
                persistent_workers=False,
                batch_outside_worker=False,
                parallelism_mode='mp'):
        if isinstance(dataset, IterDataPipe):
            data_loader: Any = None
            if batch_sampler is not None:
                raise Exception(
                    'batch_sampler is not yet supported by DataPipes')
            if sampler is not None:
                raise Exception(
                    'sampler is not yet supported by DataPipes')
            datapipe = dataset
            datapipe = torch.utils.data.graph_settings.apply_shuffle_settings(datapipe, shuffle=shuffle)  # type: ignore[assignment]
            if batch_outside_worker and pin_memory:
                raise Exception(
                    'pin_memory is not yet compatible with batch_outside_worker')
            if parallelism_mode == 'mp' or num_workers == 0:
                if not batch_outside_worker:
                    if batch_size is not None:
                        datapipe = datapipe.batch(batch_size, drop_last=drop_last)
                        if collate_fn is None:
                            collate_fn = torch.utils.data._utils.collate.default_collate

                # Note: It is safe to pass shuffle=True to the old DataLoader, as shuffle does nothing
                # for Iterable, but required to set Pipes correctly.
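                # In 'mp' mode (and when num_workers == 0) the heavy lifting is delegated
                # to the classic DataLoader: batching was already folded into the DataPipe
                # via .batch() above, so batch_size/drop_last are neutralized below, while
                # collation, worker processes, pinning and prefetching remain the
                # DataLoader's responsibility.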
                data_loader = DataLoader(datapipe,
                                         batch_size=None,  # Replaced by .batch DataPipe
                                         shuffle=shuffle,
                                         sampler=None,
                                         batch_sampler=None,
                                         num_workers=num_workers,
                                         collate_fn=collate_fn,
                                         pin_memory=pin_memory,
                                         drop_last=False,  # Replaced by .batch DataPipe
                                         timeout=timeout,
                                         worker_init_fn=worker_init_fn,
                                         prefetch_factor=prefetch_factor,
                                         persistent_workers=persistent_workers)
            elif parallelism_mode == 'thread':
                if collate_fn is not None and not batch_outside_worker:
                    datapipe = datapipe.map(collate_fn)
                if pin_memory:
                    raise Exception(
                        'pin_memory is not yet supported by DataPipes with Threading')
                if worker_init_fn is not None:
                    raise Exception(
                        'worker_init_fn is not yet supported by DataPipes with Threading')
                data_loader = _ThreadingDataLoader2(datapipe,
                                                    num_workers=num_workers,
                                                    collate_fn=collate_fn)
            else:
                raise Exception('Unsupported parallelism mode', parallelism_mode)

            if not batch_outside_worker:
                return data_loader
            else:
                if collate_fn is None:
                    collate_fn = torch.utils.data._utils.collate.default_collate
                datapipe = IterableWrapper(data_loader).batch(
                    batch_size, drop_last=drop_last).map(collate_fn)
                return datapipe
        else:
            if parallelism_mode == 'thread':
                raise Exception(
                    'thread parallelism mode is not supported for old DataSets')
            return DataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=shuffle,
                              sampler=sampler,
                              batch_sampler=batch_sampler,
                              num_workers=num_workers,
                              collate_fn=collate_fn,
                              pin_memory=pin_memory,
                              drop_last=drop_last,
                              timeout=timeout,
                              worker_init_fn=worker_init_fn,
                              prefetch_factor=prefetch_factor,
                              persistent_workers=persistent_workers)
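

# Usage sketch (illustrative only, not part of this module's API): a minimal way
# to drive DataLoader2 with an IterDataPipe, assuming the functional DataPipe
# methods .shuffle()/.batch() available in this torch version. The toy pipe and
# the parameter values below are placeholders for demonstration.
if __name__ == "__main__":
    # A trivial in-memory DataPipe; .shuffle() makes it shufflable so that
    # apply_shuffle_settings() inside DataLoader2 can honor shuffle=True.
    example_dp = IterableWrapper(list(range(100))).shuffle()

    # Default ('mp') path with num_workers=0: batching is folded into the
    # DataPipe and the classic DataLoader handles collation and iteration.
    example_loader = DataLoader2(example_dp, batch_size=8, shuffle=True, num_workers=0)
    for example_batch in example_loader:
        pass  # each example_batch is a collated tensor with up to 8 elements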