from typing import Iterator, List

from torch.utils.data import BatchSampler, Sampler


class MonoTaskBatchSampler(BatchSampler):
    """A batch sampler that groups indices by task so that every batch only
    contains samples from a single task.

    Args:
        sampler (Sampler): Base sampler that yields dataset indices. Its
            ``dataset`` attribute must implement ``get_task_idx(idx)``.
        batch_size (int): Size of each mini-batch.
        num_tasks (int): Number of tasks in the dataset.
        drop_last (bool): Whether to drop the last incomplete batch of each
            task. Defaults to False.
    """

    def __init__(self,
                 sampler: Sampler,
                 batch_size: int,
                 num_tasks: int,
                 drop_last: bool = False) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        if not isinstance(num_tasks, int) or num_tasks <= 0:
            raise ValueError('num_tasks should be a positive integer value, '
                             f'but got num_tasks={num_tasks}')
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_tasks = num_tasks
        self._task_buckets: List[List[int]] = [[] for _ in range(num_tasks)]

    def __iter__(self) -> Iterator[List[int]]:
        for idx in self.sampler:
            bucket_id = self.sampler.dataset.get_task_idx(idx)
            bucket = self._task_buckets[bucket_id]
            bucket.append(idx)
            # Yield a batch of indices belonging to the same task.
            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]

        # Yield the remaining incomplete batches (unless ``drop_last`` is set)
        # and reset the buckets for the next epoch.
        left_data = [
            bucket for bucket in self._task_buckets if len(bucket) > 0
        ]
        self._task_buckets = [[] for _ in range(self.num_tasks)]
        if not self.drop_last:
            for data in left_data:
                yield data

    def __len__(self) -> int:
        # Note: this is an approximation. Since indices are grouped per task,
        # each task may contribute its own incomplete batch, so the actual
        # number of batches can differ slightly from this estimate.
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
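

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). It assumes a dataset that
# exposes ``get_task_idx(idx)`` and a sampler that exposes a ``dataset``
# attribute, as the class above expects; the toy dataset and sampler names
# below are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import Dataset, SequentialSampler

    class _ToyMultiTaskDataset(Dataset):
        """A toy dataset where each sample belongs to one of two tasks."""

        def __init__(self) -> None:
            self.task_ids = [0, 1, 0, 0, 1, 1, 0, 1, 0, 1]

        def __len__(self) -> int:
            return len(self.task_ids)

        def __getitem__(self, idx: int) -> int:
            return idx

        def get_task_idx(self, idx: int) -> int:
            return self.task_ids[idx]

    class _DatasetSampler(SequentialSampler):
        """Sequential sampler that also exposes a ``dataset`` attribute."""

        def __init__(self, dataset: Dataset) -> None:
            super().__init__(dataset)
            self.dataset = dataset

    dataset = _ToyMultiTaskDataset()
    batch_sampler = MonoTaskBatchSampler(
        _DatasetSampler(dataset), batch_size=3, num_tasks=2)
    for batch in batch_sampler:
        # Each printed batch contains indices of a single task only.
        print(batch)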