Spaces:
Sleeping
Sleeping
| from torch.utils.data import DataLoader | |
class InfiniteLoader(DataLoader):
    """A ``DataLoader`` that acts as its own iterator and can cycle forever.

    The loader holds a single underlying ``DataLoader`` iterator in
    ``dataset_iterator``. ``iter(loader)`` returns the loader itself, and
    ``next(loader)`` pulls the next batch from that stored iterator. When the
    stored iterator is exhausted a fresh one is created; what happens next
    depends on ``is_infinite``:

    * ``is_infinite=True``  -- the first batch of the fresh iterator is
      returned, so iteration never stops on a non-empty dataset.
    * ``is_infinite=False`` -- ``StopIteration`` is raised to end the current
      pass; the fresh iterator is kept, so a later loop starts a new pass.

    Args:
        *args: Positional arguments forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.
        pin_memory: Forwarded to ``DataLoader`` (note the default here is
            ``True``).
        is_infinite: Whether to restart transparently on exhaustion.
        **kwargs: Remaining keyword arguments forwarded to ``DataLoader``.
            If ``multiprocessing_context`` is not supplied, it defaults to
            ``"fork"`` when ``num_workers > 0`` and to ``None`` otherwise.
    """

    def __init__(
        self,
        *args,
        num_workers=0,
        pin_memory=True,
        is_infinite=True,
        **kwargs,
    ):
        # "fork" is only a default: an explicit ``multiprocessing_context``
        # from the caller must win. Passing it unconditionally next to
        # ``**kwargs`` would raise ``TypeError`` (duplicate keyword argument)
        # whenever the caller supplies their own value.
        kwargs.setdefault(
            "multiprocessing_context",
            "fork" if num_workers > 0 else None,
        )
        super().__init__(
            *args,
            num_workers=num_workers,
            pin_memory=pin_memory,
            **kwargs,
        )
        # The underlying iterator is created eagerly, at construction time.
        self.dataset_iterator = super().__iter__()
        self.is_infinite = is_infinite

    def __iter__(self):
        """Return the loader itself; batches are produced by ``__next__``."""
        return self

    def __next__(self):
        """Return the next batch, restarting the underlying iterator if needed.

        Raises:
            StopIteration: When the underlying iterator is exhausted and
                ``is_infinite`` is false, or when a freshly created iterator
                yields nothing at all (e.g. an empty dataset).
        """
        try:
            return next(self.dataset_iterator)
        except StopIteration:
            # Always swap in a fresh iterator so the loader stays usable,
            # whether we keep going now or end this pass.
            self.dataset_iterator = super().__iter__()
            if not self.is_infinite:
                raise
            return next(self.dataset_iterator)