"""Some wrapping utilities extended from pytorch's to support repeat factor sampling in particular""" |
|
|
|
from typing import Iterable |
|
|
|
import torch |
|
from torch.utils.data import ( |
|
ConcatDataset as TorchConcatDataset, |
|
Dataset, |
|
Subset as TorchSubset, |
|
) |


class ConcatDataset(TorchConcatDataset):
    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__(datasets)

        # Concatenate the per-dataset repeat factors so they line up with the
        # indices of the concatenated dataset.
        self.repeat_factors = torch.cat([d.repeat_factors for d in datasets])

    def set_epoch(self, epoch: int):
        # Forward the epoch to any child dataset that tracks it, keeping
        # epoch-dependent behavior (e.g. resampling) in sync.
        for dataset in self.datasets:
            if hasattr(dataset, "epoch"):
                dataset.epoch = epoch
            if hasattr(dataset, "set_epoch"):
                dataset.set_epoch(epoch)


class Subset(TorchSubset):
    def __init__(self, dataset: Dataset, indices: Sequence[int]) -> None:
        super().__init__(dataset, indices)

        # Keep only the repeat factors of the retained indices.
        self.repeat_factors = dataset.repeat_factors[indices]
        assert len(indices) == len(self.repeat_factors)


class RepeatFactorWrapper(Dataset):
    """
    Thin wrapper around a dataset to implement repeat factor sampling.

    The underlying dataset must have a repeat_factors member giving the
    per-image repeat factor; set it to all ones to disable repeat factor
    sampling. See the usage sketch at the bottom of this module.
    """

    def __init__(self, dataset, seed: int = 0):
        self.dataset = dataset
        self.epoch_ids = None
        self._seed = seed

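        # Split the repeat factors into whole (_int_part) and fractional
        # (_frac_part) parts; the fractional part is rounded stochastically
        # per epoch in _get_epoch_indices.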
        self._int_part = torch.trunc(dataset.repeat_factors)
        self._frac_part = dataset.repeat_factors - self._int_part

    def _get_epoch_indices(self, generator):
        """
        Create a list of dataset indices (with repeats) to use for one epoch.

        Args:
            generator (torch.Generator): pseudo random number generator used for
                stochastic rounding.

        Returns:
            torch.Tensor: list of dataset indices to use in one epoch. Each index
                is repeated based on its calculated repeat factor.
        """
        rands = torch.rand(len(self._frac_part), generator=generator)
        rep_factors = self._int_part + (rands < self._frac_part).float()
        # Repeat each index as many times as its rounded repeat factor says.
        indices = []
        for dataset_index, rep_factor in enumerate(rep_factors):
            indices.extend([dataset_index] * int(rep_factor.item()))
        return torch.tensor(indices, dtype=torch.int64)

    def __len__(self):
        if self.epoch_ids is None:
            raise RuntimeError("Please call set_epoch first to get the wrapped length.")
        return len(self.epoch_ids)

    def set_epoch(self, epoch: int):
        # Seed with self._seed + epoch so each epoch draws a different but
        # reproducible set of repeated indices.
        g = torch.Generator()
        g.manual_seed(self._seed + epoch)
        self.epoch_ids = self._get_epoch_indices(g)
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def __getitem__(self, idx):
        if self.epoch_ids is None:
            raise RuntimeError(
                "Repeat ids haven't been computed. Did you forget to call set_epoch?"
            )
        return self.dataset[self.epoch_ids[idx]]
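

# A minimal usage sketch, not part of the library API: it assumes only what the
# wrapper requires, i.e. a map-style dataset exposing a float `repeat_factors`
# tensor. `_ToyDataset` is a hypothetical stand-in for a real image dataset.
if __name__ == "__main__":

    class _ToyDataset(Dataset):
        def __init__(self):
            # Item 0 is never repeated; item 2 averages 2.5 copies per epoch.
            self.repeat_factors = torch.tensor([1.0, 1.5, 2.5])

        def __len__(self):
            return len(self.repeat_factors)

        def __getitem__(self, idx):
            return int(idx)

    wrapped = RepeatFactorWrapper(_ToyDataset(), seed=0)
    for epoch in range(2):
        # set_epoch must be called before len() or indexing.
        wrapped.set_epoch(epoch)
        print(f"epoch {epoch}: {[wrapped[i] for i in range(len(wrapped))]}")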