from abc import ABC | |
from abc import abstractmethod | |
from typing import Optional | |
from typing import Tuple | |
import torch | |
class AbsEncoder(torch.nn.Module, ABC): | |
def output_size(self) -> int: | |
raise NotImplementedError | |
def forward( | |
self, | |
xs_pad: torch.Tensor, | |
ilens: torch.Tensor, | |
prev_states: torch.Tensor = None, | |
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: | |
raise NotImplementedError | |