| | |
| | import numpy as np |
| | from torch.utils.data.sampler import BatchSampler, Sampler |
| |
|
| |
|
class GroupedBatchSampler(BatchSampler):
    """
    Wraps another sampler to yield mini-batches of indices.

    It enforces that each batch only contains elements from the same group.
    It also tries to provide mini-batches which follow an ordering that is
    as close as possible to the ordering from the original sampler.

    Note:
        The per-group buffers are stored on the instance, not on the
        iterator. Indices that have not filled a complete batch when the
        base sampler is exhausted are not yielded by that pass; they remain
        buffered and are combined with indices seen in a later iteration
        over the same object.
    """

    def __init__(self, sampler, group_ids, batch_size):
        """
        Args:
            sampler (Sampler): Base sampler.
            group_ids (list[int]): If the sampler produces indices in range [0, N),
                `group_ids` must be a list of `N` ints which contains the group id of each sample.
                The group ids must be a set of integers in the range [0, num_groups).
            batch_size (int): Size of mini-batch. Must be positive.

        Raises:
            ValueError: If `sampler` is not a `torch.utils.data.Sampler`,
                if `group_ids` is not one-dimensional, or if `batch_size`
                is not positive.
        """
        if not isinstance(sampler, Sampler):
            raise ValueError(
                "sampler should be an instance of "
                "torch.utils.data.Sampler, but got sampler={}".format(sampler)
            )

        group_ids = np.asarray(group_ids)
        # Explicit raise instead of `assert`: assertions are stripped when
        # Python runs with -O, which would let malformed input through.
        if group_ids.ndim != 1:
            raise ValueError(
                "group_ids should be a 1-dimensional sequence, "
                "but got an array with ndim={}".format(group_ids.ndim)
            )

        # A non-positive batch size can never be reached by a buffer length,
        # so iteration would yield nothing while the buffers keep growing.
        if batch_size <= 0:
            raise ValueError(
                "batch_size should be a positive integer, "
                "but got batch_size={}".format(batch_size)
            )

        self.sampler = sampler
        self.group_ids = group_ids
        self.batch_size = batch_size

        # Buffer the indices of each group until the batch size is reached.
        # `tolist()` converts the unique ids to plain Python scalars, which
        # serve as the dictionary keys.
        groups = np.unique(self.group_ids).tolist()
        self.buffer_per_group = {k: [] for k in groups}

    def __iter__(self):
        """
        Yield lists of `batch_size` indices that all share one group id.

        Indices are consumed from the base sampler in order; a batch is
        emitted as soon as the buffer of its group becomes full.
        """
        for idx in self.sampler:
            group_id = self.group_ids[idx]
            group_buffer = self.buffer_per_group[group_id]
            group_buffer.append(idx)
            if len(group_buffer) == self.batch_size:
                # Hand out a copy so that clearing the buffer below does not
                # empty the batch the consumer is holding.
                yield group_buffer[:]
                del group_buffer[:]

    def __len__(self):
        """
        Always raises: the number of batches is not well-defined here.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("len() of GroupedBatchSampler is not well-defined.")
| |
|