# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# domainbed/lib/fast_data_loader.py

import torch
from .datasets.ab_dataset import ABDataset


class _InfiniteSampler(torch.utils.data.Sampler):
    """Wraps another Sampler to yield an infinite stream."""

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            for batch in self.sampler:
                yield batch


class InfiniteDataLoader:
    """Wraps a DataLoader to yield random batches endlessly, without respawning workers."""

    def __init__(self, dataset, weights, batch_size, num_workers, collate_fn=None):
        super().__init__()

        if weights is not None:
            # Weighted sampling with replacement; each pass over the sampler draws
            # exactly one batch, and _InfiniteSampler restarts it indefinitely.
            sampler = torch.utils.data.WeightedRandomSampler(
                weights, replacement=True, num_samples=batch_size
            )
        else:
            sampler = torch.utils.data.RandomSampler(dataset, replacement=True)

        batch_sampler = torch.utils.data.BatchSampler(
            sampler, batch_size=batch_size, drop_last=True
        )

        # DataLoader treats collate_fn=None as the default collate, so a single
        # construction covers both the custom and default cases.
        self._infinite_iterator = iter(
            torch.utils.data.DataLoader(
                dataset,
                num_workers=num_workers,
                batch_sampler=_InfiniteSampler(batch_sampler),
                pin_memory=False,
                collate_fn=collate_fn,
            )
        )
        self.dataset = dataset

    def __iter__(self):
        while True:
            yield next(self._infinite_iterator)

    def __len__(self):
        raise ValueError("InfiniteDataLoader has no length; it yields batches indefinitely.")


class FastDataLoader:
    """
    DataLoader wrapper with slightly improved speed by not respawning worker
    processes at every epoch.
    """

    def __init__(self, dataset, batch_size, num_workers, shuffle=False, collate_fn=None):
        super().__init__()
        
        self.num_workers = num_workers

        if shuffle:
            sampler = torch.utils.data.RandomSampler(dataset, replacement=False)
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)

        batch_sampler = torch.utils.data.BatchSampler(
            sampler,
            batch_size=batch_size,
            drop_last=False,
        )
        # DataLoader treats collate_fn=None as the default collate, so a single
        # construction covers both the custom and default cases.
        self._infinite_iterator = iter(
            torch.utils.data.DataLoader(
                dataset,
                num_workers=num_workers,
                batch_sampler=_InfiniteSampler(batch_sampler),
                pin_memory=False,
                collate_fn=collate_fn,
            )
        )

        self.dataset = dataset
        self.batch_size = batch_size
        self._length = len(batch_sampler)

    def __iter__(self):
        for _ in range(len(self)):
            yield next(self._infinite_iterator)

    def __len__(self):
        return self._length


def build_dataloader(dataset: ABDataset, batch_size: int, num_workers: int, infinite: bool, shuffle_when_finite: bool, collate_fn=None):
    """Return an InfiniteDataLoader (endless random batches) when `infinite`, else a FastDataLoader (single pass)."""
    assert batch_size <= len(dataset), f'batch_size {batch_size} exceeds dataset size {len(dataset)}'
    if infinite:
        dataloader = InfiniteDataLoader(
            dataset, None, batch_size, num_workers=num_workers, collate_fn=collate_fn)
    else:
        dataloader = FastDataLoader(
            dataset, batch_size, num_workers, shuffle=shuffle_when_finite, collate_fn=collate_fn)

    return dataloader


def get_a_batch_dataloader(dataset: ABDataset, batch_size: int, num_workers: int, infinite: bool, shuffle_when_finite: bool):
    # Placeholder: not implemented yet.
    pass
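

# --- Minimal usage sketch (illustration only, not part of the original module). ---
# It uses a plain TensorDataset as a stand-in for ABDataset, since the loaders
# above only rely on len(dataset) and standard indexing. Because of the relative
# import at the top of this file, run it as a module from the project root
# (e.g. `python -m domainbed.lib.fast_data_loader`, given the path in the header
# comment) rather than as a standalone script.
if __name__ == '__main__':
    from torch.utils.data import TensorDataset

    toy_dataset = TensorDataset(torch.randn(16, 3), torch.randint(0, 2, (16,)))

    # Infinite loader: random batches drawn with replacement, no defined length.
    infinite_loader = build_dataloader(
        toy_dataset, batch_size=4, num_workers=0, infinite=True, shuffle_when_finite=False)
    infinite_iter = iter(infinite_loader)
    for _ in range(3):
        x, y = next(infinite_iter)
        print('infinite batch:', x.shape, y.shape)

    # Finite loader: one sequential pass over the dataset, len() is defined.
    finite_loader = build_dataloader(
        toy_dataset, batch_size=4, num_workers=0, infinite=False, shuffle_when_finite=False)
    print('finite loader yields', len(finite_loader), 'batches')
    for x, y in finite_loader:
        print('finite batch:', x.shape, y.shape)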