# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# domainbed/lib/fast_data_loader.py
import torch
from .datasets.ab_dataset import ABDataset


class _InfiniteSampler(torch.utils.data.Sampler):
    """Wraps another Sampler to yield an infinite stream."""
    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            for batch in self.sampler:
                yield batch


class InfiniteDataLoader:
    """Yields batches forever by wrapping the batch sampler in _InfiniteSampler."""
    def __init__(self, dataset, weights, batch_size, num_workers, collate_fn=None):
        super().__init__()
        if weights:
            # Weighted sampling with replacement (e.g. for class balancing).
            sampler = torch.utils.data.WeightedRandomSampler(
                weights, replacement=True, num_samples=batch_size
            )
        else:
            sampler = torch.utils.data.RandomSampler(dataset, replacement=True)

        batch_sampler = torch.utils.data.BatchSampler(
            sampler, batch_size=batch_size, drop_last=True
        )

        if collate_fn is not None:
            self._infinite_iterator = iter(
                torch.utils.data.DataLoader(
                    dataset,
                    num_workers=num_workers,
                    batch_sampler=_InfiniteSampler(batch_sampler),
                    pin_memory=False,
                    collate_fn=collate_fn
                )
            )
        else:
            self._infinite_iterator = iter(
                torch.utils.data.DataLoader(
                    dataset,
                    num_workers=num_workers,
                    batch_sampler=_InfiniteSampler(batch_sampler),
                    pin_memory=False
                )
            )

        self.dataset = dataset

    def __iter__(self):
        while True:
            yield next(self._infinite_iterator)

    def __len__(self):
        # The stream is infinite, so a length is undefined.
        raise ValueError


class FastDataLoader:
    """
    DataLoader wrapper with slightly improved speed by not respawning worker
    processes at every epoch.
    """
    def __init__(self, dataset, batch_size, num_workers, shuffle=False, collate_fn=None):
        super().__init__()
        self.num_workers = num_workers

        if shuffle:
            sampler = torch.utils.data.RandomSampler(dataset, replacement=False)
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)

        batch_sampler = torch.utils.data.BatchSampler(
            sampler,
            batch_size=batch_size,
            drop_last=False,
        )

        # The underlying DataLoader iterates forever and keeps its workers alive;
        # __iter__/__len__ below slice it back into epochs of len(batch_sampler) batches.
        if collate_fn is not None:
            self._infinite_iterator = iter(
                torch.utils.data.DataLoader(
                    dataset,
                    num_workers=num_workers,
                    batch_sampler=_InfiniteSampler(batch_sampler),
                    pin_memory=False,
                    collate_fn=collate_fn
                )
            )
        else:
            self._infinite_iterator = iter(
                torch.utils.data.DataLoader(
                    dataset,
                    num_workers=num_workers,
                    batch_sampler=_InfiniteSampler(batch_sampler),
                    pin_memory=False,
                )
            )

        self.dataset = dataset
        self.batch_size = batch_size
        self._length = len(batch_sampler)

    def __iter__(self):
        for _ in range(len(self)):
            yield next(self._infinite_iterator)

    def __len__(self):
        return self._length


def build_dataloader(dataset: ABDataset, batch_size: int, num_workers: int, infinite: bool, shuffle_when_finite: bool, collate_fn=None):
    assert batch_size <= len(dataset), f'batch_size ({batch_size}) must not exceed dataset size ({len(dataset)})'

    if infinite:
        dataloader = InfiniteDataLoader(
            dataset, None, batch_size, num_workers=num_workers, collate_fn=collate_fn)
    else:
        dataloader = FastDataLoader(
            dataset, batch_size, num_workers, shuffle=shuffle_when_finite, collate_fn=collate_fn)

    return dataloader


def get_a_batch_dataloader(dataset: ABDataset, batch_size: int, num_workers: int, infinite: bool, shuffle_when_finite: bool):
    # Not implemented.
    pass
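

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original module): illustrates how the
    # two loaders behave on a toy TensorDataset. The toy dataset is an assumption for
    # demonstration only; real callers pass an ABDataset, e.g. via build_dataloader.
    from torch.utils.data import TensorDataset

    toy = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))

    # InfiniteDataLoader never raises StopIteration, so take a fixed number of steps.
    inf_loader = InfiniteDataLoader(toy, weights=None, batch_size=4, num_workers=0)
    for step, (x, y) in zip(range(3), inf_loader):
        print('infinite step', step, x.shape, y.shape)

    # FastDataLoader behaves like a normal epoch loader (here ceil(10 / 4) = 3 batches,
    # since drop_last=False) while reusing the same underlying iterator across epochs.
    fast_loader = FastDataLoader(toy, batch_size=4, num_workers=0, shuffle=False)
    print('batches per epoch:', len(fast_loader))
    for x, y in fast_loader:
        print('finite batch', x.shape, y.shape)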