poiqazwsx's picture
Upload 57 files
51e2f90
raw
history blame
1.06 kB
import warnings
from typing import Dict, Optional, Union
import torch
from torch import nn
from torch.utils import data
class AugmentedDataset(data.Dataset):
def __init__(
self,
dataset: data.Dataset,
augmentation: nn.Module = nn.Identity(),
target_length: Optional[int] = None,
) -> None:
warnings.warn(
"This class is no longer used. Attach augmentation to "
"the LightningSystem instead.",
DeprecationWarning,
)
self.dataset = dataset
self.augmentation = augmentation
self.ds_length: int = len(dataset) # type: ignore[arg-type]
self.length = target_length if target_length is not None else self.ds_length
def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str,
torch.Tensor]]]:
item = self.dataset[index % self.ds_length]
item = self.augmentation(item)
return item
def __len__(self) -> int:
return self.length