|
from typing import Iterator, Optional |
|
|
|
import torch |
|
from torch.utils.data import Dataset, IterableDataset |
|
|
|
|
|
class ValidationWrapper(Dataset):
    """Wraps a dataset so that PyTorch Lightning's validation step can be turned into a
    visualization step.

    The wrapper reports a fixed, caller-chosen length and ignores the index it is
    asked for: an ``IterableDataset`` is consumed sequentially (and restarted when it
    runs out), while any other dataset is sampled at a random position.
    """

    # The wrapped dataset that samples are drawn from.
    dataset: Dataset
    # Lazily created iterator over ``dataset``; only used for iterable datasets.
    dataset_iterator: Optional[Iterator]
    # Number of items this wrapper claims to contain.
    length: int

    def __init__(self, dataset: Dataset, length: int) -> None:
        """Store the wrapped dataset and the length to report.

        Args:
            dataset: The dataset to draw samples from.
            length: The value returned by ``len()`` on this wrapper.
        """
        super().__init__()
        self.dataset = dataset
        self.length = length
        # Created on first access so that construction stays cheap.
        self.dataset_iterator = None

    def __len__(self):
        """Return the fixed length given at construction time."""
        return self.length

    def __getitem__(self, index: tuple):
        """Return a sample from the wrapped dataset; ``index`` is ignored.

        For an ``IterableDataset`` the next element of the stream is returned, and
        the stream is restarted once it is exhausted so that all ``len(self)``
        requests can be served. For any other dataset a random element is fetched
        with a random number of context views.

        Raises:
            RuntimeError: If the wrapped iterable dataset yields no samples at all.
        """
        if isinstance(self.dataset, IterableDataset):
            if self.dataset_iterator is None:
                self.dataset_iterator = iter(self.dataset)
            try:
                return next(self.dataset_iterator)
            except StopIteration:
                # A StopIteration escaping from __getitem__ is interpreted as
                # "end of data" by iteration machinery and silently truncates the
                # run, so start a fresh pass over the stream instead.
                self.dataset_iterator = iter(self.dataset)
            try:
                return next(self.dataset_iterator)
            except StopIteration:
                # Even a fresh iterator is empty: fail loudly rather than leak
                # StopIteration to the caller.
                self.dataset_iterator = None
                raise RuntimeError(
                    "The wrapped IterableDataset did not yield any samples."
                ) from None

        # Map-style path: the wrapped dataset must expose
        # ``view_sampler.num_context_views`` and accept an
        # ``(index, num_context_views)`` tuple as its key.
        random_index = torch.randint(0, len(self.dataset), tuple())
        # Upper bound is exclusive, hence the +1; at least 2 context views are drawn.
        random_context_num = torch.randint(
            2, self.dataset.view_sampler.num_context_views + 1, tuple()
        )

        return self.dataset[random_index.item(), random_context_num.item()]
|
|