from abc import ABC | |
from abc import abstractmethod | |
from typing import Iterator | |
from typing import Tuple | |
from torch.utils.data import Sampler | |
class AbsSampler(Sampler, ABC):
    """Abstract base class for samplers that yield batches of string keys.

    Concrete subclasses must implement ``__len__`` and ``__iter__``. Because
    both are declared with ``@abstractmethod``, a subclass that omits either
    one raises ``TypeError`` when it is instantiated.
    """

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of batches this sampler yields."""
        raise NotImplementedError

    @abstractmethod
    def __iter__(self) -> Iterator[Tuple[str, ...]]:
        """Yield batches, each a tuple of string keys.

        NOTE(review): the keys are presumably sample identifiers; confirm
        this against the concrete subclasses.
        """
        raise NotImplementedError

    def generate(self, seed):
        """Return all batches from one pass over the sampler, as a list.

        This default implementation ignores ``seed`` and simply materializes
        ``iter(self)``. Subclasses whose batching depends on randomness may
        override it to use ``seed``.
        """
        return list(self)